path: root/llama.cpp/ggml/src/ggml-metal
author     Mitja Felicijan <mitja.felicijan@gmail.com>  2026-02-12 20:57:17 +0100
committer  Mitja Felicijan <mitja.felicijan@gmail.com>  2026-02-12 20:57:17 +0100
commit     b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree       211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-metal
download   llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-metal')
-rw-r--r--  llama.cpp/ggml/src/ggml-metal/CMakeLists.txt            124
-rw-r--r--  llama.cpp/ggml/src/ggml-metal/ggml-metal-common.cpp     446
-rw-r--r--  llama.cpp/ggml/src/ggml-metal/ggml-metal-common.h        52
-rw-r--r--  llama.cpp/ggml/src/ggml-metal/ggml-metal-context.h       41
-rw-r--r--  llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m      702
-rw-r--r--  llama.cpp/ggml/src/ggml-metal/ggml-metal-device.cpp    1875
-rw-r--r--  llama.cpp/ggml/src/ggml-metal/ggml-metal-device.h       290
-rw-r--r--  llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m      1748
-rw-r--r--  llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h        1051
-rw-r--r--  llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp       4222
-rw-r--r--  llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.h           93
-rw-r--r--  llama.cpp/ggml/src/ggml-metal/ggml-metal.cpp            937
-rw-r--r--  llama.cpp/ggml/src/ggml-metal/ggml-metal.metal          9798
13 files changed, 21379 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt b/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt
new file mode 100644
index 0000000..42054d8
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt
@@ -0,0 +1,124 @@
1find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
2find_library(METAL_FRAMEWORK Metal REQUIRED)
3find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
4
5message(STATUS "Metal framework found")
6
7ggml_add_backend_library(ggml-metal
8 ggml-metal.cpp
9 ggml-metal-device.m
10 ggml-metal-device.cpp
11 ggml-metal-common.cpp
12 ggml-metal-context.m
13 ggml-metal-ops.cpp
14 )
15
16target_link_libraries(ggml-metal PRIVATE
17 ${FOUNDATION_LIBRARY}
18 ${METAL_FRAMEWORK}
19 ${METALKIT_FRAMEWORK}
20 )
21
22if (GGML_METAL_NDEBUG)
23 add_compile_definitions(GGML_METAL_NDEBUG)
24endif()
25
26set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
27if (GGML_METAL_EMBED_LIBRARY)
28 enable_language(ASM)
29
30 add_compile_definitions(GGML_METAL_EMBED_LIBRARY)
31
32 set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
33 set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
34
35 file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated")
36
37 # merge ggml-common.h and ggml-metal.metal into a single file
38 set(METALLIB_EMBED_ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
39 set(METALLIB_SOURCE_EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
40 set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
41
42 add_custom_command(
43 OUTPUT "${METALLIB_EMBED_ASM}"
44 COMMAND echo "Embedding Metal library"
45 COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${METALLIB_SOURCE}" > "${METALLIB_SOURCE_EMBED_TMP}"
46 COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${METALLIB_SOURCE_EMBED_TMP}" > "${METALLIB_SOURCE_EMBED}"
47 COMMAND echo ".section __DATA,__ggml_metallib" > "${METALLIB_EMBED_ASM}"
48 COMMAND echo ".globl _ggml_metallib_start" >> "${METALLIB_EMBED_ASM}"
49 COMMAND echo "_ggml_metallib_start:" >> "${METALLIB_EMBED_ASM}"
50 COMMAND echo .incbin "\"${METALLIB_SOURCE_EMBED}\"" >> "${METALLIB_EMBED_ASM}"
51 COMMAND echo ".globl _ggml_metallib_end" >> "${METALLIB_EMBED_ASM}"
52 COMMAND echo "_ggml_metallib_end:" >> "${METALLIB_EMBED_ASM}"
53 DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h
54 COMMENT "Generate assembly for embedded Metal library"
55 VERBATIM
56 )
57
58 target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}")
59else()
60 # copy metal files to bin directory
61 configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
62 configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
63 configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
64
65 if (GGML_METAL_SHADER_DEBUG)
66 # custom command to do the following:
67 # xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
68 # xcrun -sdk macosx metallib ggml-metal.air -o default.metallib
69 #
70 # note: this is the only way I found to disable fast-math in Metal. it's ugly, but at least it works
71 # disabling fast math is needed in order to pass tests/test-backend-ops
72 # note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1
73 # note: unfortunately, we have to call it default.metallib instead of ggml.metallib
74 # ref: https://github.com/ggml-org/whisper.cpp/issues/1720
75 # note: adding -g causes segmentation fault during compile
76 #set(XC_FLAGS -fno-fast-math -fno-inline -g)
77 set(XC_FLAGS -fno-fast-math -fno-inline)
78 else()
79 set(XC_FLAGS -O3)
80 endif()
81
82 # Append macOS metal versioning flags
83 if (GGML_METAL_MACOSX_VERSION_MIN)
84 message(STATUS "Adding -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN} flag to metal compilation")
85 list (APPEND XC_FLAGS -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN})
86 endif()
87
88 if (GGML_METAL_STD)
89 message(STATUS "Adding -std=${GGML_METAL_STD} flag to metal compilation")
90 list (APPEND XC_FLAGS -std=${GGML_METAL_STD})
91 endif()
92
93 add_custom_command(
94 OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
95 COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - |
96 xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
97 COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h
98 COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal
99 DEPENDS ggml-metal.metal ${METALLIB_COMMON}
100 COMMENT "Compiling Metal kernels"
101 )
102
103 # FIXME: only add to the ggml-metal target?
104 add_custom_target(
105 ggml-metal-lib ALL
106 DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
107 )
108endif() # GGML_METAL_EMBED_LIBRARY
109
110if (NOT GGML_METAL_EMBED_LIBRARY)
111 install(
112 FILES src/ggml-metal/ggml-metal.metal
113 PERMISSIONS
114 OWNER_READ
115 OWNER_WRITE
116 GROUP_READ
117 WORLD_READ
118 DESTINATION ${CMAKE_INSTALL_BINDIR})
119
120 install(
121 FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib
122 DESTINATION ${CMAKE_INSTALL_BINDIR}
123 )
124endif()
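Note: with GGML_METAL_EMBED_LIBRARY, the assembly generated above exports the
_ggml_metallib_start and _ggml_metallib_end symbols. A minimal sketch of how such
section symbols are typically consumed from C/C++ (the helper below is hypothetical,
shown only to illustrate that Mach-O drops the leading underscore on the C side):

    // hypothetical consumer of the embedded Metal library section
    extern "C" const char ggml_metallib_start[]; // _ggml_metallib_start in the generated .s
    extern "C" const char ggml_metallib_end[];   // _ggml_metallib_end in the generated .s

    static size_t embedded_metallib_size(void) {
        return (size_t) (ggml_metallib_end - ggml_metallib_start); // source spans [start, end)
    }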
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.cpp b/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.cpp
new file mode 100644
index 0000000..95627d3
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.cpp
@@ -0,0 +1,446 @@
1#include "ggml-metal-common.h"
2
3#include "ggml-impl.h"
4#include "ggml-backend-impl.h"
5
6#include <vector>
7
8// represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb)
9// the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it)
10struct ggml_mem_range {
11 uint64_t pb; // buffer id
12
13 uint64_t p0; // begin
14 uint64_t p1; // end
15
16 ggml_mem_range_type pt;
17};
18
19struct ggml_mem_ranges {
20 std::vector<ggml_mem_range> ranges;
21
22 int debug = 0;
23};
24
25ggml_mem_ranges_t ggml_mem_ranges_init(int debug) {
26 auto * res = new ggml_mem_ranges;
27
28 res->ranges.reserve(256);
29 res->debug = debug;
30
31 return res;
32}
33
34void ggml_mem_ranges_free(ggml_mem_ranges_t mrs) {
35 delete mrs;
36}
37
38void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs) {
39 mrs->ranges.clear();
40}
41
42static bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, ggml_mem_range mr) {
43 mrs->ranges.push_back(mr);
44
45 return true;
46}
47
48static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggml_mem_range_type pt) {
49 // always use the base tensor
50 tensor = tensor->view_src ? tensor->view_src : tensor;
51
52 GGML_ASSERT(!tensor->view_src);
53
54 ggml_mem_range mr;
55
56 if (tensor->buffer) {
57 // when the tensor is allocated, use the actual memory address range in the buffer
58 //
59 // take the actual allocated size with ggml_backend_buft_get_alloc_size()
60 // this can be larger than the tensor size if the buffer type allocates extra memory
61 // ref: https://github.com/ggml-org/llama.cpp/pull/15966
62 mr = {
63 /*.pb =*/ (uint64_t) tensor->buffer,
64 /*.p0 =*/ (uint64_t) tensor->data,
65 /*.p1 =*/ (uint64_t) tensor->data + ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor),
66 /*.pt =*/ pt,
67 };
68 } else {
69 // otherwise, the pointer address is used as a unique id of the memory ranges
70 // that the tensor will be using when it is allocated
71 mr = {
72 /*.pb =*/ (uint64_t) tensor,
73 /*.p0 =*/ 0, //
74 /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used
75 /*.pt =*/ pt,
76 };
77 };
78
79 return mr;
80}
81
82static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) {
83 return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC);
84}
85
86static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor) {
87 return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST);
88}
89
90static bool ggml_mem_ranges_add_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
91 GGML_ASSERT(tensor);
92
93 ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
94
95 if (mrs->debug > 2) {
96 GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
97 }
98
99 return ggml_mem_ranges_add(mrs, mr);
100}
101
102static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
103 GGML_ASSERT(tensor);
104
105 ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
106
107 if (mrs->debug > 2) {
108 GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
109 }
110
111 return ggml_mem_ranges_add(mrs, mr);
112}
113
114bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
115 for (int i = 0; i < GGML_MAX_SRC; i++) {
116 if (tensor->src[i]) {
117 ggml_mem_ranges_add_src(mrs, tensor->src[i]);
118 }
119 }
120
121 return ggml_mem_ranges_add_dst(mrs, tensor);
122}
123
124static bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, ggml_mem_range mr) {
125 for (size_t i = 0; i < mrs->ranges.size(); i++) {
126 const auto & cmp = mrs->ranges[i];
127
128 // two memory ranges cannot intersect if they are in different buffers
129 if (mr.pb != cmp.pb) {
130 continue;
131 }
132
133 // intersecting source ranges are allowed
134 if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
135 continue;
136 }
137
138 if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) {
139 if (mrs->debug > 2) {
140 GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n",
141 __func__,
142 mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
143 mr.pb, mr.p0, mr.p1,
144 cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
145 cmp.pb, cmp.p0, cmp.p1);
146 }
147
148 return false;
149 }
150 }
151
152 return true;
153}
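// note: the check above treats memory ranges conservatively; a standalone
// sketch of the same predicate with a worked example (the helper is
// hypothetical, for illustration only):
//
//   static bool ranges_overlap(uint64_t p0, uint64_t p1, uint64_t q0, uint64_t q1) {
//       return p0 < q1 && p1 >= q0; // touching ranges also count as overlapping
//   }
//
//   // ranges_overlap(0, 100, 200, 300) -> false : 100 <  200, safe to run concurrently
//   // ranges_overlap(0, 100, 100, 200) -> true  : 100 >= 100, conservatively serialized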
154
155static bool ggml_mem_ranges_check_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
156 GGML_ASSERT(tensor);
157
158 ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
159
160 const bool res = ggml_mem_ranges_check(mrs, mr);
161
162 return res;
163}
164
165static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
166 GGML_ASSERT(tensor);
167
168 ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
169
170 const bool res = ggml_mem_ranges_check(mrs, mr);
171
172 return res;
173}
174
175bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) {
176 for (int i = 0; i < GGML_MAX_SRC; i++) {
177 if (tensor->src[i]) {
178 if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
179 return false;
180 }
181 }
182 }
183
184 return ggml_mem_ranges_check_dst(mrs, tensor);
185}
186
187struct node_info {
188 ggml_tensor * node;
189
190 std::vector<ggml_tensor *> fused;
191
192 ggml_op op() const {
193 return node->op;
194 }
195
196 const ggml_tensor * dst() const {
197 return fused.empty() ? node : fused.back();
198 }
199
200 bool is_empty() const {
201 return ggml_op_is_empty(node->op);
202 }
203
204 void add_fused(ggml_tensor * t) {
205 fused.push_back(t);
206 }
207};
208
209static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
210 // helper to add node src and dst ranges
211 const auto & h_add = [](ggml_mem_ranges_t mrs, const node_info & node) {
212 for (int i = 0; i < GGML_MAX_SRC; i++) {
213 if (node.node->src[i]) {
214 if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) {
215 return false;
216 }
217 }
218 }
219
220 // keep track of the sources of the fused nodes as well
221 for (const auto * fused : node.fused) {
222 for (int i = 0; i < GGML_MAX_SRC; i++) {
223 if (fused->src[i]) {
224 if (!ggml_mem_ranges_add_src(mrs, fused->src[i])) {
225 return false;
226 }
227 }
228 }
229 }
230
231 return ggml_mem_ranges_add_dst(mrs, node.dst());
232 };
233
234 // helper to check if a node can run concurrently with the existing set of nodes
235 const auto & h_check = [](ggml_mem_ranges_t mrs, const node_info & node) {
236 for (int i = 0; i < GGML_MAX_SRC; i++) {
237 if (node.node->src[i]) {
238 if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) {
239 return false;
240 }
241 }
242 }
243
244 for (const auto * fused : node.fused) {
245 for (int i = 0; i < GGML_MAX_SRC; i++) {
246 if (fused->src[i]) {
247 if (!ggml_mem_ranges_check_src(mrs, fused->src[i])) {
248 return false;
249 }
250 }
251 }
252 }
253
254 return ggml_mem_ranges_check_dst(mrs, node.dst());
255 };
256
257 // perform reorders only across these types of ops
258 // can be expanded when needed
259 const auto & h_safe = [](ggml_op op) {
260 switch (op) {
261 case GGML_OP_MUL_MAT:
262 case GGML_OP_MUL_MAT_ID:
263 case GGML_OP_ROPE:
264 case GGML_OP_NORM:
265 case GGML_OP_RMS_NORM:
266 case GGML_OP_GROUP_NORM:
267 case GGML_OP_SUM_ROWS:
268 case GGML_OP_MUL:
269 case GGML_OP_ADD:
270 case GGML_OP_DIV:
271 case GGML_OP_GLU:
272 case GGML_OP_SCALE:
273 case GGML_OP_GET_ROWS:
274 case GGML_OP_CPY:
275 case GGML_OP_SET_ROWS:
276 return true;
277 default:
278 return ggml_op_is_empty(op);
279 }
280 };
281
282 const int n = nodes.size();
283
284 std::vector<int> res;
285 res.reserve(n);
286
287 std::vector<bool> used(n, false);
288
289 // the memory ranges for the set of currently concurrent nodes
290 ggml_mem_ranges_t mrs0 = ggml_mem_ranges_init(0);
291
292 // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
293 ggml_mem_ranges_t mrs1 = ggml_mem_ranges_init(0);
294
295 for (int i0 = 0; i0 < n; i0++) {
296 if (used[i0]) {
297 continue;
298 }
299
300 const auto & node0 = nodes[i0];
301
302 // the node is not concurrent with the existing concurrent set, so we have to "put a barrier" (i.e. reset mrs0)
303 // but before we do that, look forward for some other nodes that can be added to the concurrent set mrs0
304 //
305 // note: we can always add empty nodes to the concurrent set as they neither read nor write anything
306 if (!node0.is_empty() && !h_check(mrs0, node0)) {
307 // this will hold the set of memory ranges from the nodes that haven't been processed yet
308 // if a node is not concurrent with this set, we cannot reorder it
309 ggml_mem_ranges_reset(mrs1);
310
311 // initialize it with the current node
312 h_add(mrs1, node0);
313
314 // how many nodes to look forward when searching for a concurrent node
315 constexpr int N_FORWARD = 8;
316
317 for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
318 if (used[i1]) {
319 continue;
320 }
321
322 const auto & node1 = nodes[i1];
323
324 // disallow reordering of certain ops
325 if (!h_safe(node1.op())) {
326 break;
327 }
328
329 const bool is_empty = node1.is_empty();
330
331 // to reorder a node and add it to the concurrent set, it has to be:
332 // + empty or concurrent with all nodes in the existing concurrent set (mrs0)
333 // + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
334 if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
335 // add the node to the existing concurrent set (i.e. reorder it for early execution)
336 h_add(mrs0, node1);
337 res.push_back(i1);
338
339 // mark as used, so we skip re-processing it later
340 used[i1] = true;
341 } else {
342 // expand the set of nodes that haven't been processed yet
343 h_add(mrs1, node1);
344 }
345 }
346
347 // finalize the concurrent set and begin a new one
348 ggml_mem_ranges_reset(mrs0);
349 }
350
351 // expand the concurrent set with the current node
352 {
353 h_add(mrs0, node0);
354 res.push_back(i0);
355 }
356 }
357
358 ggml_mem_ranges_free(mrs0);
359 ggml_mem_ranges_free(mrs1);
360
361 return res;
362}
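// worked example (hypothetical 4-node graph, single buffer, all ops reorder-safe):
//   n0: MUL_MAT reads A, writes X
//   n1: ADD     reads X, writes Y   (conflicts with n0: X is n0's dst)
//   n2: MUL_MAT reads B, writes Z   (no conflict with n0 or n1)
//   n3: ADD     reads Z, writes W   (conflicts with n2: Z is n2's dst)
// n0 enters the concurrent set; n1 fails h_check(mrs0), so the N_FORWARD scan
// runs: n2 passes both h_check(mrs0) and h_check(mrs1) and is pulled forward,
// while n3 conflicts with n2 and stays put. resulting order: n0, n2, n1, n3 -
// n0 and n2 can now be encoded without a barrier between them.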
363
364void ggml_graph_optimize(ggml_cgraph * gf) {
365 constexpr int MAX_FUSE = 16;
366
367 const int n = gf->n_nodes;
368
369 enum ggml_op ops[MAX_FUSE];
370
371 std::vector<node_info> nodes;
372 nodes.reserve(gf->n_nodes);
373
374 // fuse nodes:
375 // we don't want to make reorders that break fusing, so we first pack all fusable tensors
376 // and perform the reorder over the fused nodes. after the reorder is done, we unfuse
377 for (int i = 0; i < n; i++) {
378 node_info node = {
379 /*.node =*/ gf->nodes[i],
380 /*.fused =*/ {},
381 };
382
383 // fuse only ops that start with these operations
384 // can be expanded when needed
385 if (node.op() == GGML_OP_ADD ||
386 node.op() == GGML_OP_NORM ||
387 node.op() == GGML_OP_RMS_NORM) {
388 ops[0] = node.op();
389
390 int f = i + 1;
391 while (f < n && f < i + MAX_FUSE) {
392 // conservatively allow fusing only these ops
393 // can be expanded when needed
394 if (gf->nodes[f]->op != GGML_OP_ADD &&
395 gf->nodes[f]->op != GGML_OP_MUL &&
396 gf->nodes[f]->op != GGML_OP_NORM &&
397 gf->nodes[f]->op != GGML_OP_RMS_NORM) {
398 break;
399 }
400 ops[f - i] = gf->nodes[f]->op;
401 f++;
402 }
403
404 f -= i;
405 for (; f > 1; f--) {
406 if (ggml_can_fuse(gf, i, ops, f)) {
407 break;
408 }
409 }
410
411 // add the fused tensors into the node info so we can unfuse them later
412 for (int k = 1; k < f; k++) {
413 ++i;
414
415 // the .dst() becomes the last fused tensor
416 node.add_fused(gf->nodes[i]);
417 }
418 }
419
420 nodes.push_back(std::move(node));
421 }
422
423#if 1
424 // reorder to improve concurrency
425 const auto order = ggml_metal_graph_optimize_reorder(nodes);
426#else
427 std::vector<int> order(nodes.size());
428 for (size_t i = 0; i < nodes.size(); i++) {
429 order[i] = i;
430 }
431#endif
432
433 // unfuse
434 {
435 int j = 0;
436 for (const auto i : order) {
437 const auto & node = nodes[i];
438
439 gf->nodes[j++] = node.node;
440
441 for (auto * fused : node.fused) {
442 gf->nodes[j++] = fused;
443 }
444 }
445 }
446}
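// worked example of the fuse/reorder/unfuse flow above (hypothetical graph):
//   RMS_NORM -> MUL -> MUL_MAT
// RMS_NORM opens a fuse window; if ggml_can_fuse() accepts {RMS_NORM, MUL},
// the MUL tensor is attached via add_fused() and the pair travels through
// ggml_metal_graph_optimize_reorder() as a single unit whose dst() is the MUL
// tensor. the unfuse step then writes the nodes back into gf->nodes in the
// new order, keeping each fused run adjacent.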
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.h b/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.h
new file mode 100644
index 0000000..3acbc6a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.h
@@ -0,0 +1,52 @@
1// helper functions for ggml-metal that are too difficult to implement in Objective-C
2
3#pragma once
4
5#include <stdbool.h>
6
7#ifdef __cplusplus
8extern "C" {
9#endif
10
11struct ggml_tensor;
12struct ggml_cgraph;
13
14enum ggml_mem_range_type {
15 MEM_RANGE_TYPE_SRC = 0,
16 MEM_RANGE_TYPE_DST = 1,
17};
18
19// a helper object that can be used for reordering operations to improve concurrency
20//
21// the fundamental idea is that a set of tasks (either ggml ops, or something else) can run concurrently if no
22// task writes to memory that is being read or written by another task in the set
23//
24// with this structure, we can add tasks to the set, setting memory constraints. we can also check if a new task
25// can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the
26// tasks already in the set)
27//
28typedef struct ggml_mem_ranges * ggml_mem_ranges_t;
29
30ggml_mem_ranges_t ggml_mem_ranges_init(int debug);
31void ggml_mem_ranges_free(ggml_mem_ranges_t mrs);
32
33// remove all ranges from the set
34void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs);
35
36// add src or dst ranges to track
37bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor);
38
39// return false if:
40// - new src range overlaps with any existing dst range
41// - new dst range overlaps with any existing range (src or dst)
42bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor);
43
44// reorder the nodes in the graph to improve concurrency, while respecting fusion
45//
46// note: this implementation is generic and not specific to metal
47// if it proves to work well, we can start using it for other backends in the future
48void ggml_graph_optimize(struct ggml_cgraph * gf);
49
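// usage sketch (hypothetical tensors t0 and t1, assumed to be allocated in the
// same backend buffer):
//
//   ggml_mem_ranges_t mrs = ggml_mem_ranges_init(/*debug =*/ 0);
//
//   ggml_mem_ranges_add(mrs, t0);          // track t0's srcs and dst
//   if (ggml_mem_ranges_check(mrs, t1)) {  // no conflict -> t1 can run concurrently with t0
//       ggml_mem_ranges_add(mrs, t1);
//   } else {
//       ggml_mem_ranges_reset(mrs);        // barrier: start a new concurrent set
//       ggml_mem_ranges_add(mrs, t1);
//   }
//
//   ggml_mem_ranges_free(mrs);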
50#ifdef __cplusplus
51}
52#endif
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.h b/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.h
new file mode 100644
index 0000000..abf4b06
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.h
@@ -0,0 +1,41 @@
1#pragma once
2
3#include "ggml-metal-device.h"
4
5#ifdef __cplusplus
6extern "C" {
7#endif
8
9//
10// backend context
11//
12
13typedef struct ggml_metal * ggml_metal_t;
14
15ggml_metal_t ggml_metal_init(ggml_metal_device_t dev);
16void ggml_metal_free(ggml_metal_t ctx);
17
18const char * ggml_metal_get_name(ggml_metal_t ctx);
19
20void ggml_metal_synchronize(ggml_metal_t ctx);
21
22void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
23void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
24bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst);
25
26enum ggml_status ggml_metal_graph_compute (ggml_metal_t ctx, struct ggml_cgraph * gf);
27void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf);
28
29void ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev);
30void ggml_metal_event_wait (ggml_metal_t ctx, ggml_metal_event_t ev);
31
32ggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx);
33
34void ggml_metal_set_n_cb (ggml_metal_t ctx, int n_cb);
35void ggml_metal_set_abort_callback (ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data);
36bool ggml_metal_supports_family (ggml_metal_t ctx, int family);
37void ggml_metal_capture_next_compute(ggml_metal_t ctx);
38
39#ifdef __cplusplus
40}
41#endif
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m b/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m
new file mode 100644
index 0000000..5d3a8ce
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m
@@ -0,0 +1,702 @@
1#import "ggml-metal-context.h"
2
3#import "ggml-impl.h"
4#import "ggml-backend-impl.h"
5
6#import "ggml-metal-impl.h"
7#import "ggml-metal-common.h"
8#import "ggml-metal-ops.h"
9
10#import <Foundation/Foundation.h>
11
12#import <Metal/Metal.h>
13
14#undef MIN
15#undef MAX
16#define MIN(a, b) ((a) < (b) ? (a) : (b))
17#define MAX(a, b) ((a) > (b) ? (a) : (b))
18
19// max number of MTLCommandBuffer used to submit a graph for processing
20#define GGML_METAL_MAX_COMMAND_BUFFERS 8
21
22struct ggml_metal_command_buffer {
23 id<MTLCommandBuffer> obj;
24};
25
26struct ggml_metal {
27 char name[128];
28
29 ggml_metal_device_t dev;
30 ggml_metal_library_t lib;
31
32 ggml_metal_event_t ev_cpy; // for async copies
33
34 dispatch_queue_t d_queue;
35
36 // additional, inference-time compiled pipelines
37 ggml_metal_pipelines_t pipelines_ext;
38
39 bool use_fusion;
40 bool use_concurrency;
41 bool use_graph_optimize;
42
43 int debug_graph;
44 int debug_fusion;
45
46 // how many times a given op was fused
47 uint64_t fuse_cnt[GGML_OP_COUNT];
48
49 // capture state
50 bool capture_next_compute;
51 bool capture_started;
52
53 id<MTLCaptureScope> capture_scope;
54
55 // command buffer state
56 int n_cb; // number of extra threads used to submit the command buffers
57 int n_nodes_0; // number of nodes submitted by the main thread
58 int n_nodes_1; // remaining number of nodes submitted by the n_cb threads
59 int n_nodes_per_cb;
60
61 struct ggml_cgraph * gf;
62
63 // the callback given to the thread pool
64 void (^encode_async)(size_t ith);
65
66 // n_cb command buffers + 1 used by the main thread
67 struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
68
69 // extra command buffers for things like getting, setting and copying tensors
70 NSMutableArray * cmd_bufs_ext;
71
72 // the last command buffer queued into the Metal queue with operations relevant to the current Metal backend
73 id<MTLCommandBuffer> cmd_buf_last;
74
75 // abort ggml_metal_graph_compute if callback returns true
76 ggml_abort_callback abort_callback;
77 void * abort_callback_data;
78};
79
80ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
81 GGML_LOG_INFO("%s: allocating\n", __func__);
82
83#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
84 // Show all the Metal device instances in the system
85 NSArray * devices = MTLCopyAllDevices();
86 for (id<MTLDevice> device in devices) {
87 GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
88 }
89 [devices release]; // since it was created by a *Copy* C method
90#endif
91
92 // init context
93 ggml_metal_t res = calloc(1, sizeof(struct ggml_metal));
94
95 id<MTLDevice> device = ggml_metal_device_get_obj(dev);
96
97 GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
98
99 // TODO: would it be better to have one queue for the backend and one queue for the device?
100 // the graph encoders and async ops would use the backend queue while the sync ops would use the device queue?
101 //res->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND]
102 id<MTLCommandQueue> queue = ggml_metal_device_get_queue(dev);
103 if (queue == nil) {
104 GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
105 return NULL;
106 }
107
108 res->dev = dev;
109 res->lib = ggml_metal_device_get_library(dev);
110 if (res->lib == NULL) {
111 GGML_LOG_WARN("%s: the device does not have a precompiled Metal library - this is unexpected\n", __func__);
112 GGML_LOG_WARN("%s: will try to compile it on the fly\n", __func__);
113
114 res->lib = ggml_metal_library_init(dev);
115 if (res->lib == NULL) {
116 GGML_LOG_ERROR("%s: error: failed to initialize the Metal library\n", __func__);
117
118 free(res);
119
120 return NULL;
121 }
122 }
123
124 res->ev_cpy = ggml_metal_device_event_init(dev);
125
126 const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);
127
128 snprintf(res->name, sizeof(res->name), "%s", props_dev->name);
129
130 res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
131
132 res->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
133 res->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil;
134
135 {
136 const char * val = getenv("GGML_METAL_GRAPH_DEBUG");
137 res->debug_graph = val ? atoi(val) : 0;
138 }
139
140 {
141 const char * val = getenv("GGML_METAL_FUSION_DEBUG");
142 res->debug_fusion = val ? atoi(val) : 0;
143 }
144
145 res->use_graph_optimize = true;
146
147 if (getenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE") != NULL) {
148 res->use_graph_optimize = false;
149 }
150
151 memset(res->fuse_cnt, 0, sizeof(res->fuse_cnt));
152
153 GGML_LOG_INFO("%s: use fusion = %s\n", __func__, res->use_fusion ? "true" : "false");
154 GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, res->use_concurrency ? "true" : "false");
155 GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false");
156
157 res->capture_next_compute = false;
158 res->capture_started = false;
159 res->capture_scope = nil;
160
161 res->gf = nil;
162 res->encode_async = nil;
163 for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
164 res->cmd_bufs[i].obj = nil;
165 }
166
167 res->cmd_bufs_ext = [[NSMutableArray alloc] init];
168
169 res->cmd_buf_last = nil;
170
171 res->pipelines_ext = ggml_metal_pipelines_init();
172
173 return res;
174}
175
176void ggml_metal_free(ggml_metal_t ctx) {
177 GGML_LOG_INFO("%s: deallocating\n", __func__);
178
179 for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
180 if (ctx->cmd_bufs[i].obj) {
181 [ctx->cmd_bufs[i].obj release];
182 }
183 }
184
185 for (int i = 0; i < (int) ctx->cmd_bufs_ext.count; ++i) {
186 if (ctx->cmd_bufs_ext[i]) {
187 [ctx->cmd_bufs_ext[i] release];
188 }
189 }
190
191 [ctx->cmd_bufs_ext removeAllObjects];
192 [ctx->cmd_bufs_ext release];
193
194 if (ctx->pipelines_ext) {
195 ggml_metal_pipelines_free(ctx->pipelines_ext);
196 ctx->pipelines_ext = nil;
197 }
198
199 if (ctx->debug_fusion > 0) {
200 GGML_LOG_DEBUG("%s: fusion stats:\n", __func__);
201 for (int i = 0; i < GGML_OP_COUNT; i++) {
202 if (ctx->fuse_cnt[i] == 0) {
203 continue;
204 }
205
206 // note: cannot use ggml_log here
207 GGML_LOG_DEBUG("%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
208 }
209 }
210
211 Block_release(ctx->encode_async);
212
213 //[ctx->queue release]; // [TAG_QUEUE_PER_BACKEND]
214
215 dispatch_release(ctx->d_queue);
216
217 ggml_metal_device_event_free(ctx->dev, ctx->ev_cpy);
218
219 free(ctx);
220}
221
222const char * ggml_metal_get_name(ggml_metal_t ctx) {
223 return ctx->name;
224}
225
226void ggml_metal_synchronize(ggml_metal_t ctx) {
227 // wait for any backend operations to finish
228 if (ctx->cmd_buf_last) {
229 [ctx->cmd_buf_last waitUntilCompleted];
230 ctx->cmd_buf_last = nil;
231 }
232
233 // check status of all command buffers
234 {
235 const int n_cb = ctx->n_cb;
236
237 for (int cb_idx = 0; cb_idx <= n_cb; ++cb_idx) {
238 id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
239 if (!cmd_buf) {
240 continue;
241 }
242
243 MTLCommandBufferStatus status = [cmd_buf status];
244 if (status != MTLCommandBufferStatusCompleted) {
245 GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, cb_idx, (int) status);
246 if (status == MTLCommandBufferStatusError) {
247 GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
248 }
249 GGML_ABORT("fatal error");
250 }
251 }
252 }
253
254 // release any completed extra command buffers
255 if (ctx->cmd_bufs_ext.count > 0) {
256 for (size_t i = 0; i < ctx->cmd_bufs_ext.count; ++i) {
257 id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_ext[i];
258
259 MTLCommandBufferStatus status = [cmd_buf status];
260 if (status != MTLCommandBufferStatusCompleted) {
261 GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, (int) i, (int) status);
262 if (status == MTLCommandBufferStatusError) {
263 GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
264 }
265 GGML_ABORT("fatal error");
266 }
267
268 [cmd_buf release];
269 }
270
271 [ctx->cmd_bufs_ext removeAllObjects];
272 }
273}
274
275static struct ggml_metal_buffer_id ggml_metal_get_buffer_id(const struct ggml_tensor * t) {
276 if (!t) {
277 return (struct ggml_metal_buffer_id) { nil, 0 };
278 }
279
280 ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
281
282 return ggml_metal_buffer_get_id(buffer->context, t);
283}
284
285void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
286 @autoreleasepool {
287 // wrap the source data into a Metal buffer
288 id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
289 id<MTLBuffer> buf_src = [device newBufferWithBytes:data
290 length:size
291 options:MTLResourceStorageModeShared];
292
293 GGML_ASSERT(buf_src);
294
295 struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(tensor);
296 if (bid_dst.metal == nil) {
297 GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
298 }
299
300 bid_dst.offs += offset;
301
302 // queue the copy operation into the queue of the Metal context
303 // this will be queued at the end, after any currently ongoing GPU operations
304 id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
305 id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
306 id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
307
308 [encoder copyFromBuffer:buf_src
309 sourceOffset:0
310 toBuffer:bid_dst.metal
311 destinationOffset:bid_dst.offs
312 size:size];
313
314 [encoder endEncoding];
315 [cmd_buf commit];
316 [buf_src release];
317
318 // do not wait here for completion
319 //[cmd_buf waitUntilCompleted];
320
321 // instead, remember a reference to the command buffer and wait for it later if needed
322 [ctx->cmd_bufs_ext addObject:cmd_buf];
323 ctx->cmd_buf_last = cmd_buf;
324
325 [cmd_buf retain];
326 }
327}
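// usage sketch from the public backend API (assumption: this function is
// reached through the standard ggml-backend entry points):
//
//   ggml_backend_tensor_set_async(backend, tensor, data, 0, ggml_nbytes(tensor));
//   // ... queue more work ...
//   ggml_backend_synchronize(backend); // waits on cmd_buf_last (see ggml_metal_synchronize)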
328
329void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
330 @autoreleasepool {
331 id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
332 id<MTLBuffer> buf_dst = [device newBufferWithBytesNoCopy:data
333 length:size
334 options:MTLResourceStorageModeShared
335 deallocator:nil];
336
337 GGML_ASSERT(buf_dst);
338
339 struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(tensor);
340 if (bid_src.metal == nil) {
341 GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
342 }
343
344 bid_src.offs += offset;
345
346 // queue the copy operation into the queue of the Metal context
347 // this will be queued at the end, after any currently ongoing GPU operations
348 id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
349 id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
350 id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
351
352 [encoder copyFromBuffer:bid_src.metal
353 sourceOffset:bid_src.offs
354 toBuffer:buf_dst
355 destinationOffset:0
356 size:size];
357
358 [encoder endEncoding];
359 [cmd_buf commit];
360 [buf_dst release];
361
362 // do not wait here for completion
363 //[cmd_buf waitUntilCompleted];
364
365 // instead, remember a reference to the command buffer and wait for it later if needed
366 [ctx->cmd_bufs_ext addObject:cmd_buf];
367 ctx->cmd_buf_last = cmd_buf;
368
369 [cmd_buf retain];
370 }
371}
372
373bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) {
374 @autoreleasepool {
375 struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(src);
376 struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(dst);
377
378 if (bid_src.metal == nil || bid_dst.metal == nil) {
379 return false;
380 }
381
382 // queue the copy operation into the Metal context
383 // this will be queued at the end, after any currently ongoing GPU operations
384 id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx_src->dev);
385 id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
386 id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
387
388 [encoder copyFromBuffer:bid_src.metal
389 sourceOffset:bid_src.offs
390 toBuffer:bid_dst.metal
391 destinationOffset:bid_dst.offs
392 size:ggml_nbytes(src)];
393
394 [encoder endEncoding];
395
396 ggml_metal_event_t ev_cpy = ggml_metal_get_ev_cpy(ctx_src);
397 ggml_metal_event_encode_signal(ev_cpy, cmd_buf);
398
399 [cmd_buf commit];
400
401 // do not wait here for completion
402 //[cmd_buf waitUntilCompleted];
403
404 // instead, remember a reference to the command buffer and wait for it later if needed
405 [ctx_src->cmd_bufs_ext addObject:cmd_buf];
406 ctx_src->cmd_buf_last = cmd_buf;
407
408 [cmd_buf retain];
409
410 ggml_metal_event_wait(ctx_dst, ev_cpy);
411
412 return true;
413 }
414}
415
416enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) {
417 // number of nodes encoded by the main thread (empirically determined)
418 const int n_main = MAX(64, 0.1*gf->n_nodes);
419
420 // number of threads in addition to the main thread
421 const int n_cb = ctx->n_cb;
422
423 // keep the memory wired
424 ggml_metal_device_rsets_keep_alive(ctx->dev);
425
426 // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
427 // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
428 // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
429 // each thread creates its own command buffer and enqueues the ops in parallel
430 //
431 // tests on M1 Pro and M2 Ultra using LLaMA models show that optimal values for n_cb are 1 or 2
432
433 @autoreleasepool {
434 ctx->gf = gf;
435
436 ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);
437 ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;
438
439 ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;
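// worked example (hypothetical sizes): gf->n_nodes = 1000, n_cb = 2
//   n_main         = MAX(64, 0.1*1000)  = 100
//   n_nodes_0      = MIN(100, 1000)     = 100  (encoded by the main thread)
//   n_nodes_1      = 1000 - 100         = 900  (split across the n_cb threads)
//   n_nodes_per_cb = (900 + 2 - 1) / 2  = 450
// in encode_async below, cb_idx 0 then encodes nodes [100, 550) and cb_idx 1
// encodes nodes [550, 1000)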
440
441 const bool use_capture = ctx->capture_next_compute;
442 if (use_capture) {
443 ctx->capture_next_compute = false;
444
445 // make sure all previous computations have finished before starting the capture
446 if (ctx->cmd_buf_last) {
447 [ctx->cmd_buf_last waitUntilCompleted];
448 ctx->cmd_buf_last = nil;
449 }
450
451 if (!ctx->capture_started) {
452 // create capture scope
453 id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
454 ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:device];
455
456 MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
457 descriptor.captureObject = ctx->capture_scope;
458 descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
459 descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
460
461 NSError * error = nil;
462 if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
463 GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
464 } else {
465 [ctx->capture_scope beginScope];
466 ctx->capture_started = true;
467 }
468 }
469 }
470
471 // short-hand
472 id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
473
474 // the main thread commits the first few commands immediately
475 // cmd_buf[n_cb]
476 {
477 id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
478 [cmd_buf retain];
479
480 if (ctx->cmd_bufs[n_cb].obj) {
481 [ctx->cmd_bufs[n_cb].obj release];
482 }
483 ctx->cmd_bufs[n_cb].obj = cmd_buf;
484
485 [cmd_buf enqueue];
486
487 ctx->encode_async(n_cb);
488 }
489
490 // remember the command buffer for the next iteration
491 ctx->cmd_buf_last = ctx->cmd_bufs[n_cb].obj;
492
493 // prepare the rest of the command buffers asynchronously (optional)
494 // cmd_buf[0.. n_cb)
495 for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
496 id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences];
497 [cmd_buf retain];
498
499 if (ctx->cmd_bufs[cb_idx].obj) {
500 [ctx->cmd_bufs[cb_idx].obj release];
501 }
502 ctx->cmd_bufs[cb_idx].obj = cmd_buf;
503
504 // always enqueue the first two command buffers
505 // enqueue all of the command buffers if we don't need to abort
506 if (cb_idx < 2 || ctx->abort_callback == NULL) {
507 [cmd_buf enqueue];
508
509 // update the pointer to the last queued command buffer
510 // this is needed to implement synchronize()
511 ctx->cmd_buf_last = cmd_buf;
512 }
513 }
514
515 dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
516
517 // for debugging: block until graph is computed
518 //[ctx->cmd_buf_last waitUntilCompleted];
519
520 // enter here only when capturing in order to wait for all computation to finish
521 // otherwise, we leave the graph to compute asynchronously
522 if (!use_capture && ctx->capture_started) {
523 // wait for completion and check status of each command buffer
524 // needed to detect if the device ran out-of-memory for example (#1881)
525 {
526 id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
527 [cmd_buf waitUntilCompleted];
528
529 MTLCommandBufferStatus status = [cmd_buf status];
530 if (status != MTLCommandBufferStatusCompleted) {
531 GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
532 if (status == MTLCommandBufferStatusError) {
533 GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
534 }
535
536 return GGML_STATUS_FAILED;
537 }
538 }
539
540 for (int i = 0; i < n_cb; ++i) {
541 id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
542 [cmd_buf waitUntilCompleted];
543
544 MTLCommandBufferStatus status = [cmd_buf status];
545 if (status != MTLCommandBufferStatusCompleted) {
546 GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
547 if (status == MTLCommandBufferStatusError) {
548 GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
549 }
550
551 return GGML_STATUS_FAILED;
552 }
553
554 id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
555 if (!next_buffer) {
556 continue;
557 }
558
559 const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
560 if (next_queued) {
561 continue;
562 }
563
564 if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
565 GGML_LOG_INFO("%s: command buffer %d aborted\n", __func__, i);
566 return GGML_STATUS_ABORTED;
567 }
568
569 [next_buffer commit];
570 }
571
572 [ctx->capture_scope endScope];
573 [[MTLCaptureManager sharedCaptureManager] stopCapture];
574 }
575 }
576
577 return GGML_STATUS_SUCCESS;
578}
579
580void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf) {
581 //const int64_t t_start = ggml_time_us();
582
583 if (ctx->use_graph_optimize) {
584 ggml_graph_optimize(gf);
585 }
586
587 //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0);
588}
589
590void ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev) {
591 @autoreleasepool {
592 id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
593 id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
594
595 ggml_metal_event_encode_signal(ev, cmd_buf);
596
597 [cmd_buf commit];
598
599 [ctx->cmd_bufs_ext addObject:cmd_buf];
600 ctx->cmd_buf_last = cmd_buf;
601
602 [cmd_buf retain];
603 }
604}
605
606void ggml_metal_event_wait(ggml_metal_t ctx, ggml_metal_event_t ev) {
607 @autoreleasepool {
608 id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev);
609 id<MTLCommandBuffer> cmd_buf = [queue commandBuffer];
610
611 ggml_metal_event_encode_wait(ev, cmd_buf);
612
613 [cmd_buf commit];
614
615 [ctx->cmd_bufs_ext addObject:cmd_buf];
616 ctx->cmd_buf_last = cmd_buf;
617
618 [cmd_buf retain];
619 }
620}
621
622ggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx) {
623 return ctx->ev_cpy;
624}
625
626void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
627 if (ctx->n_cb != n_cb) {
628 ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);
629
630 if (ctx->n_cb > 2) {
631 GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb);
632 }
633 }
634
635 if (ctx->encode_async) {
636 Block_release(ctx->encode_async);
637 }
638
639 ctx->encode_async = Block_copy(^(size_t iter) {
640 const int cb_idx = iter;
641 const int n_cb_l = ctx->n_cb;
642
643 const int n_nodes_0 = ctx->n_nodes_0;
644 const int n_nodes_1 = ctx->n_nodes_1;
645
646 const int n_nodes_per_cb = ctx->n_nodes_per_cb;
647
648 int idx_start = 0;
649 int idx_end = n_nodes_0;
650
651 if (cb_idx < n_cb_l) {
652 idx_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
653 idx_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
654 }
655
656 id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
657
658 ggml_metal_op_t ctx_op = ggml_metal_op_init(
659 ctx->dev,
660 cmd_buf,
661 ctx->gf,
662 idx_start,
663 idx_end,
664 ctx->use_fusion,
665 ctx->use_concurrency,
666 ctx->capture_next_compute,
667 ctx->debug_graph,
668 ctx->debug_fusion);
669
670 for (int idx = 0; idx < ggml_metal_op_n_nodes(ctx_op); ++idx) {
671 const int res = ggml_metal_op_encode(ctx_op, idx);
672 if (res == 0) {
673 break;
674 }
675
676 idx += res - 1;
677 }
678
679 ggml_metal_op_free(ctx_op);
680
681 if (cb_idx < 2 || ctx->abort_callback == NULL) {
682 [cmd_buf commit];
683 }
684 });
685}
686
687void ggml_metal_set_abort_callback(ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data) {
688 ctx->abort_callback = abort_callback;
689 ctx->abort_callback_data = user_data;
690}
691
692bool ggml_metal_supports_family(ggml_metal_t ctx, int family) {
693 GGML_ASSERT(ctx->dev != nil);
694
695 id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev);
696
697 return [device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
698}
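// worked example: ggml_metal_supports_family(ctx, 7) evaluates
// MTLGPUFamilyApple1 + 7 - 1 == MTLGPUFamilyApple7 (M1/A14-class GPUs)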
699
700void ggml_metal_capture_next_compute(ggml_metal_t ctx) {
701 ctx->capture_next_compute = true;
702}
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.cpp b/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.cpp
new file mode 100644
index 0000000..517559d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -0,0 +1,1875 @@
1#include "ggml-metal-device.h"
2
3#include "ggml-metal-impl.h"
4
5#include "ggml-impl.h"
6
7#include <cassert>
8#include <memory>
9#include <string>
10#include <unordered_map>
11
12struct ggml_metal_device_deleter {
13 void operator()(ggml_metal_device_t ctx) {
14 ggml_metal_device_free(ctx);
15 }
16};
17
18typedef std::unique_ptr<ggml_metal_device, ggml_metal_device_deleter> ggml_metal_device_ptr;
19
20ggml_metal_device_t ggml_metal_device_get(int device) {
21 static std::vector<ggml_metal_device_ptr> devs;
22
23 devs.emplace_back(ggml_metal_device_init(device));
24
25 return devs.back().get();
26}
27
28struct ggml_metal_pipelines {
29 std::unordered_map<std::string, ggml_metal_pipeline_t> data;
30};
31
32ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
33 ggml_metal_pipelines_t res = new ggml_metal_pipelines();
34
35 return res;
36}
37
38void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) {
39 if (!ppls) {
40 return;
41 }
42
43 for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) {
44 ggml_metal_pipeline_free(it->second);
45 }
46
47 delete ppls;
48}
49
50void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline) {
51 ppls->data[name] = pipeline;
52}
53
54ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) {
55 if (ppls->data.find(name) == ppls->data.end()) {
56 return nullptr;
57 }
58
59 return ppls->data[name];
60}
61
62struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) {
63 char base[256];
64 char name[256];
65
66 const char * op_str = "undefined";
67 switch (op) {
68 case GGML_OP_ADD_ID: op_str = "add_id"; break;
69 case GGML_OP_CONCAT: op_str = "concat"; break;
70 default: GGML_ABORT("fatal error");
71 };
72
73 snprintf(base, 256, "kernel_%s", op_str);
74 snprintf(name, 256, "%s", base);
75
76 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
77 if (!res.pipeline) {
78 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
79 }
80
81 return res;
82}
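// note: the lookup-then-compile-on-miss pattern above is repeated by every
// getter in this file; a condensed sketch of the shared idea (hypothetical
// helper, not part of the actual API):
//
//   static ggml_metal_pipeline_with_params get_or_compile(ggml_metal_library_t lib, const char * base, const char * name) {
//       ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); // cache lookup
//       if (!res.pipeline) {
//           res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);          // compile and cache on miss
//       }
//       return res;
//   }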
83
84ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) {
85 char base[256];
86 char name[256];
87
88 snprintf(base, 256, "kernel_cpy_%s_%s", ggml_type_name(tsrc), ggml_type_name(tdst));
89 snprintf(name, 256, "%s", base);
90
91 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
92 if (!res.pipeline) {
93 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
94 }
95
96 return res;
97}
98
99ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
100 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
101 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
102
103 const char * pool_str = "undefined";
104 switch (op_pool) {
105 case GGML_OP_POOL_AVG: pool_str = "avg"; break;
106 case GGML_OP_POOL_MAX: pool_str = "max"; break;
107 default: GGML_ASSERT(false && "not implemented");
108 };
109
110 char base[256];
111 char name[256];
112
113 snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
114 snprintf(name, sizeof(name), "%s", base);
115
116 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
117 if (!res.pipeline) {
118 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
119 }
120
121 return res;
122}
123
124ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
125 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
126 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
127
128 const char * pool_str = "undefined";
129 switch (op_pool) {
130 case GGML_OP_POOL_AVG: pool_str = "avg"; break;
131 case GGML_OP_POOL_MAX: pool_str = "max"; break;
132 default: GGML_ASSERT(false && "not implemented");
133 };
134
135 char base[256];
136 char name[256];
137
138 snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
139 snprintf(name, 256, "%s", base);
140
141 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
142 if (!res.pipeline) {
143 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
144 }
145
146 return res;
147}
148
149ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) {
150 char base[256];
151 char name[256];
152
153 snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc));
154 snprintf(name, 256, "%s", base);
155
156 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
157 if (!res.pipeline) {
158 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
159 }
160
161 return res;
162}
163
164ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
165 char base[256];
166 char name[256];
167
168 snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx));
169 snprintf(name, 256, "%s", base);
170
171 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
172 if (!res.pipeline) {
173 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
174 }
175
176 return res;
177}
178
179ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag(ggml_metal_library_t lib, const ggml_tensor * op) {
180 char base[256];
181 char name[256];
182
183 const int n = op->src[0]->ne[0];
184
185 snprintf(base, 256, "kernel_diag_%s", ggml_type_name(op->src[0]->type));
186 snprintf(name, 256, "%s_n=%d", base, n);
187
188 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
189 if (!res.pipeline) {
190 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
191 }
192
193 res.nsg = 1;
194 res.smem = 0;
195
196 return res;
197}
198
199ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
200 char base[256];
201 char name[256];
202
203 snprintf(base, 256, "kernel_repeat_%s", ggml_type_name(tsrc));
204 snprintf(name, 256, "%s", base);
205
206 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
207 if (!res.pipeline) {
208 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
209 }
210
211 return res;
212}
213
214ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
215 char base[256];
216 char name[256];
217
218 int op_num = -1;
219
220 switch (op->op) {
221 case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break;
222 case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break;
223 case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break;
224 case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break;
225 case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break;
226 case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break;
227 case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break;
228 case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break;
229 case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;
230 case GGML_OP_UNARY:
231 switch (ggml_get_unary_op(op)) {
232 case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break;
233 case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break;
234 case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break;
235 case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break;
236 case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break;
237 case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break;
238 case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break;
239 case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break;
240 case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break;
241 case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break;
242 case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break;
243 case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break;
244 case GGML_UNARY_OP_HARDSWISH: op_num = OP_UNARY_NUM_HARDSWISH; break;
245 case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;
246 case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break;
247 case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break;
248 case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break;
249 default: GGML_ABORT("fatal error");
250 } break;
251 default: GGML_ABORT("fatal error");
252 };
253
254 const char * t0_str = ggml_type_name(op->src[0]->type);
255 const char * t_str = ggml_type_name(op->type);
256
257 const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
258 const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768;
259
260 snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
261 snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt);
262
263 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
264 if (!res.pipeline) {
265 ggml_metal_cv_t cv = ggml_metal_cv_init();
266
267 ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0);
268 ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1);
269
270 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
271
272 ggml_metal_cv_free(cv);
273 }
274
275 res.c4 = is_c4;
276 res.cnt = is_cnt;
277
278 return res;
279}
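// worked example of the base/name split above: a contiguous f32 GGML_OP_SQR
// with ne[0] % 4 == 0 and fewer than 32768 elements yields
//   base = "kernel_unary_f32_f32_4"
//   name = "kernel_unary_f32_f32_4_op=<OP_UNARY_NUM_SQR>_cnt=1"
// (the numeric op values live in ggml-metal-impl.h, which is not shown here)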
280
281ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) {
282 GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
283
284 char base[256];
285 char name[256];
286
287 const char * op_str = "undefined";
288 switch (op->op) {
289 case GGML_OP_GLU:
290 switch (ggml_get_glu_op(op)) {
291 case GGML_GLU_OP_REGLU: op_str = "reglu"; break;
292 case GGML_GLU_OP_GEGLU: op_str = "geglu"; break;
293 case GGML_GLU_OP_SWIGLU: op_str = "swiglu"; break;
294 case GGML_GLU_OP_SWIGLU_OAI: op_str = "swiglu_oai"; break;
295 case GGML_GLU_OP_GEGLU_ERF: op_str = "geglu_erf"; break;
296 case GGML_GLU_OP_GEGLU_QUICK: op_str = "geglu_quick"; break;
297 default: GGML_ABORT("fatal error");
298 } break;
299 default: GGML_ABORT("fatal error");
300 };
301
302 snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
303 snprintf(name, 256, "%s", base);
304
305 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
306 if (!res.pipeline) {
307 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
308 }
309
310 return res;
311}
312
313ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
314 assert(op->op == GGML_OP_SUM);
315
316 char base[256];
317 char name[256];
318
319 snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
320 snprintf(name, 256, "%s", base);
321
322 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
323 if (!res.pipeline) {
324 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
325 }
326
327 return res;
328}
329
330ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
331 GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
332
333 char base[256];
334 char name[256];
335
336 const char * op_str = "undefined";
337 switch (op->op) {
338 case GGML_OP_SUM_ROWS:
339 op_str = "sum_rows"; break;
340 case GGML_OP_MEAN:
341 op_str = "mean"; break;
342 default: GGML_ABORT("fatal error");
343 };
344
345 snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
346
347 snprintf(name, 256, "%s", base);
348
349 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
350 if (!res.pipeline) {
351 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
352 }
353
354 res.smem = 32*sizeof(float);
355
356 return res;
357}
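
// note: the 32*sizeof(float) = 128 bytes of threadgroup memory presumably hold one
// partial sum per simdgroup (up to 32) for the final cross-simdgroup reduction.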
358
359ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
360 GGML_ASSERT(op->op == GGML_OP_CUMSUM);
361
362 char base[256];
363 char name[256];
364
365 snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
366 snprintf(name, 256, "%s", base);
367
368 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
369 if (!res.pipeline) {
370 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
371 }
372
373 return res;
374}
375
376ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
377 GGML_ASSERT(op->op == GGML_OP_CUMSUM);
378
379 char base[256];
380 char name[256];
381
382 snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
383 snprintf(name, 256, "%s", base);
384
385 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
386 if (!res.pipeline) {
387 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
388 }
389
390 return res;
391}
392
393ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
394 GGML_ASSERT(op->op == GGML_OP_TRI);
395 GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
396
397 char base[256];
398 char name[256];
399
400 const char * op_str = "tri";
401 const int ttype = op->op_params[0];
402
403 snprintf(base, 256, "kernel_%s_%s_%d", op_str, ggml_type_name(op->src[0]->type), ttype);
404
405 snprintf(name, 256, "%s", base);
406
407 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
408 if (!res.pipeline) {
409 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
410 }
411
412 return res;
413}
414
415ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
416 GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
417
418 char base[256];
419 char name[256];
420
421 const char * suffix = "";
422
423 if (op->src[0]->ne[0] % 4 == 0) {
424 suffix = "_4";
425 }
426
427 const ggml_type tsrc1 = op->src[1] ? op->src[1]->type : GGML_TYPE_F32;
428
429 snprintf(base, 256, "kernel_soft_max_%s%s", ggml_type_name(tsrc1), suffix);
430 snprintf(name, 256, "%s", base);
431
432 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
433 if (!res.pipeline) {
434 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
435 }
436
437 res.smem = 32*sizeof(float);
438
439 return res;
440}
441
442ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
443 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
444 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
445
446 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
447 GGML_ASSERT(ggml_is_contiguous(op->src[1]));
448
449 char base[256];
450 char name[256];
451
452 const char * suffix = "";
453
454 if (op->src[1]->ne[0] % 4 == 0) {
455 suffix = "_4";
456 }
457
458 snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
459 snprintf(name, 256, "%s", base);
460
461 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
462 if (!res.pipeline) {
463 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
464 }
465
466 return res;
467}
468
469ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) {
470 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
471 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
472
473 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
474 GGML_ASSERT(ggml_is_contiguous(op->src[1]));
475
476 char base[256];
477 char name[256];
478
479 const char * suffix = "";
480 if (op->src[1]->ne[0] % 4 == 0) {
481 suffix = "_4";
482 }
483
484 snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
485 snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs);
486
487 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
488 if (!res.pipeline) {
489 ggml_metal_cv_t cv = ggml_metal_cv_init();
490
491 ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0);
492
493 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
494
495 ggml_metal_cv_free(cv);
496 }
497
498 return res;
499}
500
501ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
502 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
503
504 char base[256];
505 char name[256];
506
507 const int nsg = (ne00 + 31)/32;
508
509 snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
510 snprintf(name, 256, "%s_nsg=%d", base, nsg);
511
512 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
513 if (!res.pipeline) {
514 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
515 }
516
517 // Shared memory layout:
518 // - sgptg * NW floats for partial sums (nsg * 32)
519 // - sgptg floats for shared_x_dt (nsg)
520 // - sgptg floats for shared_dA (nsg)
521 // Total: nsg * (32 + 2) floats
522 res.smem = (32 + 2)*sizeof(float)*nsg;
523
524 return res;
525}
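
// worked example (hypothetical shape): ne00 = 128 gives nsg = (128 + 31)/32 = 4,
// so smem = (32 + 2)*sizeof(float)*4 = 544 bytes, matching the layout above.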
526
527ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) {
528 char base[256];
529 char name[256];
530
531 const int64_t C = op->ne[0];
532 const int64_t H = op->src[0]->ne[1];
533
534 switch (op->op) {
535 case GGML_OP_RWKV_WKV6:
536 {
537 GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);
538 GGML_ASSERT(C % H == 0);
539 GGML_ASSERT(C / H == 64);
540
541 snprintf(base, 256, "kernel_rwkv_wkv6_%s", ggml_type_name(op->src[0]->type));
542 } break;
543 case GGML_OP_RWKV_WKV7:
544 {
545 GGML_ASSERT(op->src[6]->type == GGML_TYPE_F32);
546 GGML_ASSERT(C % H == 0);
547 GGML_ASSERT(C / H == 64);
548
549 snprintf(base, 256, "kernel_rwkv_wkv7_%s", ggml_type_name(op->src[0]->type));
550 } break;
551 default:
552 GGML_ABORT("fatal error");
553 }
554
555 snprintf(name, 256, "%s", base);
556
557 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
558 if (!res.pipeline) {
559 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
560 }
561
562 return res;
563}
564
565ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
566 char base[256];
567 char name[256];
568
569 const int nsg = 8;
570 const int n = op->src[1]->ne[1];
571 const int k = op->src[1]->ne[0];
572
573 snprintf(base, 256, "kernel_solve_tri_%s", ggml_type_name(op->src[0]->type));
574 snprintf(name, 256, "%s_nsg=%d_n=%d_k=%d", base, nsg, n, k);
575
576 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
577 if (!res.pipeline) {
578 ggml_metal_cv_t cv = ggml_metal_cv_init();
579
580 ggml_metal_cv_set_int16(cv, nsg, FC_SOLVE_TRI + 0);
581 ggml_metal_cv_set_int16(cv, n, FC_SOLVE_TRI + 1);
582 ggml_metal_cv_set_int16(cv, k, FC_SOLVE_TRI + 2);
583
584 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
585
586 ggml_metal_cv_free(cv);
587 }
588
589 res.nsg = nsg;
590 res.smem = GGML_PAD(GGML_PAD(n, 32)*nsg*sizeof(float), 16);
591
592 return res;
593}
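
// worked example (hypothetical shape): n = 40 gives GGML_PAD(40, 32) = 64 floats per
// simdgroup, so smem = GGML_PAD(64*8*sizeof(float), 16) = 2048 bytes for nsg = 8.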
594
595ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
596 char base[256];
597 char name[256];
598
599 snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
600 snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
601
602 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
603 if (!res.pipeline) {
604 ggml_metal_cv_t cv = ggml_metal_cv_init();
605
606 ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
607 ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
608
609 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
610
611 ggml_metal_cv_free(cv);
612 }
613
614 return res;
615}
616
617ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
618 char base[256];
619 char name[256];
620
621 const ggml_type tsrc0 = op->src[0]->type;
622 const ggml_type tsrc1 = op->src[1]->type;
623
624 const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
625 const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0;
626
627 snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
628 snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
629
630 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
631 if (!res.pipeline) {
632 ggml_metal_cv_t cv = ggml_metal_cv_init();
633
634 ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
635 ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
636
637 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
638
639 ggml_metal_cv_free(cv);
640 }
641
642    // when the output size is not a multiple of 64x32, we need extra smem to prevent out-of-bounds writes
643 res.smem = bc_out ? 8192 : 4096 + 2048;
644
645 return res;
646}
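
// the 8192-byte case is presumably sized to stage a full 64x32 f32 output tile in
// threadgroup memory (64*32*sizeof(float) = 8192 bytes); the aligned fast path gets
// by with 4096 + 2048 = 6144 bytes.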
647
648ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) {
649 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
650 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
651
652 char base[256];
653 char name[256];
654
655 int nsg = 0; // number of simdgroups
656 int nr0 = 0; // number of src0 rows per simdgroup
657 int nr1 = 1; // number of src1 rows per threadgroup
658
659 size_t smem = 0; // shared memory
660
661 const ggml_type tsrc0 = op->src[0]->type;
662 const ggml_type tsrc1 = op->src[1]->type;
663
664 const char * suffix = "";
665
666 // use custom matrix x vector kernel
667 switch (tsrc0) {
668 case GGML_TYPE_F32:
669 case GGML_TYPE_F16:
670 case GGML_TYPE_BF16:
671 {
672 if (ne00 < 32) {
673 nsg = 1;
674 nr0 = 32;
675 nr1 = 1;
676 suffix = "_short";
677 } else {
678 nsg = std::min(4, (ne00 + 127) / 128);
679 nr0 = 2;
680 nr1 = 1;
681 smem = 32*sizeof(float)*nr0;
682 suffix = ne00 % 4 == 0 ? "_4" : "";
683 }
684 } break;
685 case GGML_TYPE_Q4_0:
686 {
687 nsg = N_SG_Q4_0;
688 nr0 = N_R0_Q4_0;
689 } break;
690 case GGML_TYPE_Q4_1:
691 {
692 nsg = N_SG_Q4_1;
693 nr0 = N_R0_Q4_1;
694 } break;
695 case GGML_TYPE_Q5_0:
696 {
697 nsg = N_SG_Q5_0;
698 nr0 = N_R0_Q5_0;
699 } break;
700 case GGML_TYPE_Q5_1:
701 {
702 nsg = N_SG_Q5_1;
703 nr0 = N_R0_Q5_1;
704 } break;
705 case GGML_TYPE_Q8_0:
706 {
707 nsg = N_SG_Q8_0;
708 nr0 = N_R0_Q8_0;
709 smem = 32*sizeof(float)*N_R0_Q8_0;
710 } break;
711 case GGML_TYPE_MXFP4:
712 {
713 nsg = N_SG_MXFP4;
714 nr0 = N_R0_MXFP4;
715 smem = 32*sizeof(float);
716 } break;
717 case GGML_TYPE_Q2_K:
718 {
719 nsg = N_SG_Q2_K;
720 nr0 = N_R0_Q2_K;
721 } break;
722 case GGML_TYPE_Q3_K:
723 {
724 nsg = N_SG_Q3_K;
725 nr0 = N_R0_Q3_K;
726 } break;
727 case GGML_TYPE_Q4_K:
728 {
729 nsg = N_SG_Q4_K;
730 nr0 = N_R0_Q4_K;
731 } break;
732 case GGML_TYPE_Q5_K:
733 {
734 nsg = N_SG_Q5_K;
735 nr0 = N_R0_Q5_K;
736 } break;
737 case GGML_TYPE_Q6_K:
738 {
739 nsg = N_SG_Q6_K;
740 nr0 = N_R0_Q6_K;
741 } break;
742 case GGML_TYPE_IQ2_XXS:
743 {
744 nsg = N_SG_IQ2_XXS;
745 nr0 = N_R0_IQ2_XXS;
746 smem = 256*8+128;
747 } break;
748 case GGML_TYPE_IQ2_XS:
749 {
750 nsg = N_SG_IQ2_XS;
751 nr0 = N_R0_IQ2_XS;
752 smem = 512*8+128;
753 } break;
754 case GGML_TYPE_IQ3_XXS:
755 {
756 nsg = N_SG_IQ3_XXS;
757 nr0 = N_R0_IQ3_XXS;
758 smem = 256*4+128;
759 } break;
760 case GGML_TYPE_IQ3_S:
761 {
762 nsg = N_SG_IQ3_S;
763 nr0 = N_R0_IQ3_S;
764 smem = 512*4;
765 } break;
766 case GGML_TYPE_IQ2_S:
767 {
768 nsg = N_SG_IQ2_S;
769 nr0 = N_R0_IQ2_S;
770 } break;
771 case GGML_TYPE_IQ1_S:
772 {
773 nsg = N_SG_IQ1_S;
774 nr0 = N_R0_IQ1_S;
775 } break;
776 case GGML_TYPE_IQ1_M:
777 {
778 nsg = N_SG_IQ1_M;
779 nr0 = N_R0_IQ1_M;
780 } break;
781 case GGML_TYPE_IQ4_NL:
782 {
783 nsg = N_SG_IQ4_NL;
784 nr0 = N_R0_IQ4_NL;
785 smem = 32*sizeof(float);
786 } break;
787 case GGML_TYPE_IQ4_XS:
788 {
789 nsg = N_SG_IQ4_XS;
790 nr0 = N_R0_IQ4_XS;
791 smem = 32*sizeof(float);
792 } break;
793 default:
794 {
795 GGML_LOG_ERROR("Asserting on type %d\n", (int) tsrc0);
796 GGML_ABORT("not implemented");
797 }
798 };
799
800 snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
801 snprintf(name, 256, "%s_nsg=%d", base, nsg);
802
803 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
804 if (!res.pipeline) {
805 ggml_metal_cv_t cv = ggml_metal_cv_init();
806
807 ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
808
809 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
810
811 ggml_metal_cv_free(cv);
812 }
813
814 res.nr0 = nr0;
815 res.nr1 = nr1;
816 res.nsg = nsg;
817 res.smem = smem;
818
819 return res;
820}
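
// worked example (hypothetical shape): an f32 x f32 mul_mv with ne00 = 4096 takes the
// long path with nsg = std::min(4, (4096 + 127)/128) = 4, nr0 = 2 and
// smem = 32*sizeof(float)*2 = 256 bytes, compiling "kernel_mul_mv_f32_f32_4_nsg=4";
// with ne00 = 16 the "_short" variant is used instead, with a single simdgroup and no
// threadgroup memory.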
821
822ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) {
823 char base[256];
824 char name[256];
825
826 snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
827 snprintf(name, 256, "%s_ne02=%d", base, ne02);
828
829 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
830 if (!res.pipeline) {
831 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
832 }
833
834 res.smem = (size_t) ne02*ne20*sizeof(uint16_t);
835
836 return res;
837}
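
// worked example (hypothetical values): ne02 = 4 and ne20 = 8 require
// 4*8*sizeof(uint16_t) = 64 bytes of threadgroup memory.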
838
839ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
840 char base[256];
841 char name[256];
842
843 const ggml_type tsrc0 = op->src[0]->type;
844 const ggml_type tsrc1 = op->src[1]->type;
845
846 const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
847
848 snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
849 snprintf(name, 256, "%s_bci=%d", base, bc_inp);
850
851 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
852 if (!res.pipeline) {
853 ggml_metal_cv_t cv = ggml_metal_cv_init();
854
855 ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
856
857 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
858
859 ggml_metal_cv_free(cv);
860 }
861
862 res.smem = 8192;
863
864 return res;
865}
866
867ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) {
868 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
869 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
870
871 char base[256];
872 char name[256];
873
874 int nsg = 0; // number of simdgroups
875 int nr0 = 0; // number of src0 rows per simdgroup
876 int nr1 = 1; // number of src1 rows per threadgroup
877
878 size_t smem = 0; // shared memory
879
880 const ggml_type tsrc0 = op->src[0]->type;
881 const ggml_type tsrc1 = op->src[1]->type;
882
883 const char * suffix = "";
884
885 // use custom matrix x vector kernel
886 switch (tsrc0) {
887 case GGML_TYPE_F32:
888 case GGML_TYPE_F16:
889 case GGML_TYPE_BF16:
890 {
891 nsg = std::min(4, (ne00 + 127) / 128);
892 nr0 = 2;
893 nr1 = 1;
894 smem = 32*sizeof(float)*nr0;
895 suffix = ne00 % 4 == 0 ? "_4" : "";
896 } break;
897 case GGML_TYPE_Q4_0:
898 {
899 nsg = N_SG_Q4_0;
900 nr0 = N_R0_Q4_0;
901 } break;
902 case GGML_TYPE_Q4_1:
903 {
904 nsg = N_SG_Q4_1;
905 nr0 = N_R0_Q4_1;
906 } break;
907 case GGML_TYPE_Q5_0:
908 {
909 nsg = N_SG_Q5_0;
910 nr0 = N_R0_Q5_0;
911 } break;
912 case GGML_TYPE_Q5_1:
913 {
914 nsg = N_SG_Q5_1;
915 nr0 = N_R0_Q5_1;
916 } break;
917 case GGML_TYPE_Q8_0:
918 {
919 nsg = N_SG_Q8_0;
920 nr0 = N_R0_Q8_0;
921 smem = 32*sizeof(float)*N_R0_Q8_0;
922 } break;
923 case GGML_TYPE_MXFP4:
924 {
925 nsg = N_SG_MXFP4;
926 nr0 = N_R0_MXFP4;
927 smem = 32*sizeof(float);
928 } break;
929 case GGML_TYPE_Q2_K:
930 {
931 nsg = N_SG_Q2_K;
932 nr0 = N_R0_Q2_K;
933 } break;
934 case GGML_TYPE_Q3_K:
935 {
936 nsg = N_SG_Q3_K;
937 nr0 = N_R0_Q3_K;
938 } break;
939 case GGML_TYPE_Q4_K:
940 {
941 nsg = N_SG_Q4_K;
942 nr0 = N_R0_Q4_K;
943 } break;
944 case GGML_TYPE_Q5_K:
945 {
946 nsg = N_SG_Q5_K;
947 nr0 = N_R0_Q5_K;
948 } break;
949 case GGML_TYPE_Q6_K:
950 {
951 nsg = N_SG_Q6_K;
952 nr0 = N_R0_Q6_K;
953 } break;
954 case GGML_TYPE_IQ2_XXS:
955 {
956 nsg = N_SG_IQ2_XXS;
957 nr0 = N_R0_IQ2_XXS;
958 smem = 256*8+128;
959 } break;
960 case GGML_TYPE_IQ2_XS:
961 {
962 nsg = N_SG_IQ2_XS;
963 nr0 = N_R0_IQ2_XS;
964 smem = 512*8+128;
965 } break;
966 case GGML_TYPE_IQ3_XXS:
967 {
968 nsg = N_SG_IQ3_XXS;
969 nr0 = N_R0_IQ3_XXS;
970 smem = 256*4+128;
971 } break;
972 case GGML_TYPE_IQ3_S:
973 {
974 nsg = N_SG_IQ3_S;
975 nr0 = N_R0_IQ3_S;
976 smem = 512*4;
977 } break;
978 case GGML_TYPE_IQ2_S:
979 {
980 nsg = N_SG_IQ2_S;
981 nr0 = N_R0_IQ2_S;
982 } break;
983 case GGML_TYPE_IQ1_S:
984 {
985 nsg = N_SG_IQ1_S;
986 nr0 = N_R0_IQ1_S;
987 } break;
988 case GGML_TYPE_IQ1_M:
989 {
990 nsg = N_SG_IQ1_M;
991 nr0 = N_R0_IQ1_M;
992 } break;
993 case GGML_TYPE_IQ4_NL:
994 {
995 nsg = N_SG_IQ4_NL;
996 nr0 = N_R0_IQ4_NL;
997 smem = 32*sizeof(float);
998 } break;
999 case GGML_TYPE_IQ4_XS:
1000 {
1001 nsg = N_SG_IQ4_XS;
1002 nr0 = N_R0_IQ4_XS;
1003 smem = 32*sizeof(float);
1004 } break;
1005 default:
1006 {
1007                GGML_LOG_ERROR("Asserting on type %d\n", (int) tsrc0);
1008 GGML_ABORT("not implemented");
1009 }
1010 };
1011
1012 snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
1013 snprintf(name, 256, "%s_nsg=%d", base, nsg);
1014
1015 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1016 if (!res.pipeline) {
1017 ggml_metal_cv_t cv = ggml_metal_cv_init();
1018
1019 ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
1020
1021 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1022
1023 ggml_metal_cv_free(cv);
1024 }
1025
1026 res.nr0 = nr0;
1027 res.nr1 = nr1;
1028 res.nsg = nsg;
1029 res.smem = smem;
1030
1031 return res;
1032}
1033
1034ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) {
1035 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
1036 GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
1037 GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
1038
1039 char base[256];
1040 char name[256];
1041
1042 snprintf(base, 256, "kernel_argmax_%s", ggml_type_name(op->src[0]->type));
1043 snprintf(name, 256, "%s", base);
1044
1045 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1046 if (!res.pipeline) {
1047 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1048 }
1049
1050 res.smem = 32*(sizeof(float) + sizeof(int32_t));
1051
1052 return res;
1053}
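
// note: smem = 32*(sizeof(float) + sizeof(int32_t)) = 256 bytes, presumably one
// best-value/best-index pair per simdgroup for the final argmax reduction.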
1054
1055ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) {
1056 assert(op->op == GGML_OP_ARGSORT);
1057
1058 char base[256];
1059 char name[256];
1060
1061 ggml_sort_order order = (ggml_sort_order) op->op_params[0];
1062
1063 const char * order_str = "undefined";
1064 switch (order) {
1065 case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1066 case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1067 default: GGML_ABORT("fatal error");
1068 };
1069
1070 snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1071 snprintf(name, 256, "%s", base);
1072
1073 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1074 if (!res.pipeline) {
1075 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1076 }
1077
1078 return res;
1079}
1080
1081ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
1082 assert(op->op == GGML_OP_ARGSORT);
1083
1084 char base[256];
1085 char name[256];
1086
1087 ggml_sort_order order = (ggml_sort_order) op->op_params[0];
1088
1089 const char * order_str = "undefined";
1090 switch (order) {
1091 case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1092 case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1093 default: GGML_ABORT("fatal error");
1094 };
1095
1096 snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1097 snprintf(name, 256, "%s", base);
1098
1099 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1100 if (!res.pipeline) {
1101 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1102 }
1103
1104 return res;
1105}
1106
1107// note: reuse the argsort kernel for top_k
1108ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
1109 assert(op->op == GGML_OP_TOP_K);
1110
1111 char base[256];
1112 char name[256];
1113
1114    // note: the top_k kernel always sorts in descending order
1115 ggml_sort_order order = GGML_SORT_ORDER_DESC;
1116
1117 const char * order_str = "undefined";
1118 switch (order) {
1119 case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1120 case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1121 default: GGML_ABORT("fatal error");
1122 };
1123
1124 snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1125 snprintf(name, 256, "%s", base);
1126
1127 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1128 if (!res.pipeline) {
1129 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1130 }
1131
1132 return res;
1133}
1134
1135ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
1136 assert(op->op == GGML_OP_TOP_K);
1137
1138 char base[256];
1139 char name[256];
1140
1141 ggml_sort_order order = GGML_SORT_ORDER_DESC;
1142
1143 const char * order_str = "undefined";
1144 switch (order) {
1145 case GGML_SORT_ORDER_ASC: order_str = "asc"; break;
1146 case GGML_SORT_ORDER_DESC: order_str = "desc"; break;
1147 default: GGML_ABORT("fatal error");
1148 };
1149
1150 snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
1151 snprintf(name, 256, "%s", base);
1152
1153 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1154 if (!res.pipeline) {
1155 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1156 }
1157
1158 return res;
1159}
1160
1161ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
1162 ggml_metal_library_t lib,
1163 const struct ggml_tensor * op,
1164 bool has_mask,
1165 int32_t ncpsg) {
1166 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1167 GGML_UNUSED(op);
1168
1169 char base[256];
1170 char name[256];
1171
1172 snprintf(base, 256, "kernel_%s",
1173 "flash_attn_ext_pad");
1174
1175 snprintf(name, 256, "%s_mask=%d_ncpsg=%d",
1176 base,
1177 has_mask,
1178 ncpsg);
1179
1180 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1181 if (!res.pipeline) {
1182 ggml_metal_cv_t cv = ggml_metal_cv_init();
1183
1184 ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0);
1185 //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
1186 //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2);
1187 //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3);
1188
1189 //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
1190 //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
1191 //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22);
1192 //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23);
1193 //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
1194 ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
1195
1196 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1197
1198 ggml_metal_cv_free(cv);
1199 }
1200
1201 return res;
1202}
1203
1204ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(
1205 ggml_metal_library_t lib,
1206 const struct ggml_tensor * op,
1207 int32_t nqptg,
1208 int32_t ncpsg) {
1209 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1210 GGML_UNUSED(op);
1211
1212 char base[256];
1213 char name[256];
1214
1215 snprintf(base, 256, "kernel_%s",
1216 "flash_attn_ext_blk");
1217
1218 snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d",
1219 base,
1220 nqptg,
1221 ncpsg);
1222
1223 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1224 if (!res.pipeline) {
1225 ggml_metal_cv_t cv = ggml_metal_cv_init();
1226
1227 //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0);
1228 //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
1229 //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2);
1230 //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3);
1231
1232 //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
1233 //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
1234 //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22);
1235 //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23);
1236 ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
1237 ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
1238
1239 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1240
1241 ggml_metal_cv_free(cv);
1242 }
1243
1244 return res;
1245}
1246
1247ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
1248 ggml_metal_library_t lib,
1249 const ggml_tensor * op,
1250 bool has_mask,
1251 bool has_sinks,
1252 bool has_bias,
1253 bool has_scap,
1254 bool has_kvpad,
1255 int32_t nsg) {
1256 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1257
1258 char base[256];
1259 char name[256];
1260
1261 const int32_t dk = (int32_t) op->src[1]->ne[0];
1262 const int32_t dv = (int32_t) op->src[2]->ne[0];
1263
1264 const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
1265 const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
1266
1267    // whether bounds checks are needed for the mask
1268 const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0);
1269
1270 snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
1271 "flash_attn_ext",
1272 ggml_type_name(op->src[1]->type),
1273 dk,
1274 dv);
1275
1276 snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d",
1277 base,
1278 has_mask,
1279 has_sinks,
1280 has_bias,
1281 has_scap,
1282 has_kvpad,
1283 bc_mask,
1284 ns10,
1285 ns20,
1286 nsg);
1287
1288 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1289 if (!res.pipeline) {
1290 ggml_metal_cv_t cv = ggml_metal_cv_init();
1291
1292 ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0);
1293 ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
1294 ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2);
1295 ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3);
1296 ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
1297
1298 ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
1299
1300 ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
1301 ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
1302 ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22);
1303
1304 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1305
1306 ggml_metal_cv_free(cv);
1307 }
1308
1309 return res;
1310}
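
// worked example (hypothetical shapes): f16 K/V with head sizes dk = dv = 128 resolve
// to base "kernel_flash_attn_ext_f16_dk128_dv128"; every distinct combination of the
// mask/sinks/bias/scap/kvpad flags, the bc_mask bounds check, ns10/ns20 and nsg is
// baked in via function constants and cached as a separate pipeline under `name`.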
1311
1312ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(
1313 ggml_metal_library_t lib,
1314 const ggml_tensor * op,
1315 bool has_mask,
1316 bool has_sinks,
1317 bool has_bias,
1318 bool has_scap,
1319 bool has_kvpad,
1320 int32_t nsg,
1321 int32_t nwg) {
1322 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1323
1324 char base[256];
1325 char name[256];
1326
1327 const int32_t dk = (int32_t) op->src[1]->ne[0];
1328 const int32_t dv = (int32_t) op->src[2]->ne[0];
1329
1330 const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0];
1331 const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0];
1332
1333 snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d",
1334 "flash_attn_ext_vec",
1335 ggml_type_name(op->src[1]->type),
1336 dk,
1337 dv);
1338
1339    snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
1340 base,
1341 has_mask,
1342 has_sinks,
1343 has_bias,
1344 has_scap,
1345 has_kvpad,
1346 ns10,
1347 ns20,
1348 nsg, nwg);
1349
1350 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1351 if (!res.pipeline) {
1352 ggml_metal_cv_t cv = ggml_metal_cv_init();
1353
1354 ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0);
1355 ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
1356 ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2);
1357 ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3);
1358 ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
1359
1360 ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
1361 ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
1362 ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22);
1363 ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23);
1364
1365 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1366
1367 ggml_metal_cv_free(cv);
1368 }
1369
1370 return res;
1371}
1372
1373ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
1374 ggml_metal_library_t lib,
1375 const ggml_tensor * op,
1376 int32_t dv,
1377 int32_t nwg) {
1378 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
1379
1380 char base[256];
1381 char name[256];
1382
1383 snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
1384 snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
1385
1386 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1387 if (!res.pipeline) {
1388 ggml_metal_cv_t cv = ggml_metal_cv_init();
1389
1390 ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);
1391 ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);
1392
1393 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1394
1395 ggml_metal_cv_free(cv);
1396 }
1397
1398 return res;
1399
1400 GGML_UNUSED(op);
1401}
1402
1403ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
1404 char base[256];
1405 char name[256];
1406
1407 int op_num = -1;
1408
1409 switch (op->op) {
1410 case GGML_OP_ADD: op_num = 0; break;
1411 case GGML_OP_SUB: op_num = 1; break;
1412 case GGML_OP_MUL: op_num = 2; break;
1413 case GGML_OP_DIV: op_num = 3; break;
1414 default: GGML_ABORT("fatal error");
1415 };
1416
1417 const char * t0_str = ggml_type_name(op->src[0]->type);
1418 const char * t1_str = ggml_type_name(op->src[1]->type);
1419 const char * t_str = ggml_type_name(op->type);
1420
1421 const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
1422
1423 const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
1424
1425 snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
1426 snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb);
1427
1428 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1429 if (!res.pipeline) {
1430 ggml_metal_cv_t cv = ggml_metal_cv_init();
1431
1432 ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
1433 ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
1434 ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2);
1435
1436 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1437
1438 ggml_metal_cv_free(cv);
1439 }
1440
1441 res.c4 = is_c4;
1442 res.cnt = is_rb;
1443
1444 return res;
1445}
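
// worked example (hypothetical op): a fused ADD of two f32 tensors with both ne[0] % 4
// == 0, n_fuse = 2 and no row broadcast resolves to base "kernel_bin_fuse_f32_f32_f32_4"
// and cache name "kernel_bin_fuse_f32_f32_f32_4_op=0_nf=2_rb=0".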
1446
1447ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) {
1448 char base[256];
1449 char name[256];
1450
1451 int op_num = -1;
1452
1453 switch (op) {
1454 case GGML_OP_ADD: op_num = 0; break;
1455 case GGML_OP_SUB: op_num = 1; break;
1456 case GGML_OP_MUL: op_num = 2; break;
1457 case GGML_OP_DIV: op_num = 3; break;
1458 default: GGML_ABORT("fatal error");
1459 };
1460
1461 snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32");
1462 snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1);
1463
1464 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1465 if (!res.pipeline) {
1466 ggml_metal_cv_t cv = ggml_metal_cv_init();
1467
1468 ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
1469 ggml_metal_cv_set_int16(cv, 1, FC_BIN + 1);
1470 ggml_metal_cv_set_bool (cv, false, FC_BIN + 2);
1471
1472 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1473
1474 ggml_metal_cv_free(cv);
1475 }
1476
1477 return res;
1478}
1479
1480ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
1481 assert(op->op == GGML_OP_L2_NORM);
1482
1483 char base[256];
1484 char name[256];
1485
1486 const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
1487
1488 const char * t0_str = ggml_type_name(op->src[0]->type);
1489 const char * t_str = ggml_type_name(op->type);
1490
1491 snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
1492 snprintf(name, 256, "%s", base);
1493
1494 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1495 if (!res.pipeline) {
1496 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1497 }
1498
1499 res.c4 = is_c4;
1500 res.smem = 32*sizeof(float);
1501
1502 return res;
1503}
1504
1505ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
1506 assert(op->op == GGML_OP_GROUP_NORM);
1507
1508 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
1509
1510 char base[256];
1511 char name[256];
1512
1513 snprintf(base, 256, "kernel_group_norm_f32");
1514 snprintf(name, 256, "%s", base);
1515
1516 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1517 if (!res.pipeline) {
1518 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1519 }
1520
1521 res.smem = 32*sizeof(float);
1522
1523 return res;
1524}
1525
1526ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
1527 assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);
1528
1529 GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
1530
1531 char base[256];
1532 char name[256];
1533
1534 const char * suffix = "";
1535 if (op->ne[0] % 4 == 0) {
1536 suffix = "_4";
1537 }
1538
1539 switch (op->op) {
1540 case GGML_OP_NORM:
1541 switch (n_fuse) {
1542 case 1: snprintf(base, 256, "kernel_norm_f32%s", suffix); break;
1543 case 2: snprintf(base, 256, "kernel_norm_mul_f32%s", suffix); break;
1544 case 3: snprintf(base, 256, "kernel_norm_mul_add_f32%s", suffix); break;
1545 default: GGML_ABORT("fatal error");
1546 } break;
1547 case GGML_OP_RMS_NORM:
1548 switch (n_fuse) {
1549 case 1: snprintf(base, 256, "kernel_rms_norm_f32%s", suffix); break;
1550 case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32%s", suffix); break;
1551 case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32%s", suffix); break;
1552 default: GGML_ABORT("fatal error");
1553 } break;
1554 default: GGML_ABORT("fatal error");
1555 }
1556
1557 snprintf(name, 256, "%s", base);
1558
1559 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1560 if (!res.pipeline) {
1561 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1562 }
1563
1564 res.smem = 32*sizeof(float);
1565
1566 return res;
1567}
1568
1569ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {
1570 assert(op->op == GGML_OP_ROPE);
1571
1572 char base[256];
1573 char name[256];
1574
1575 const int mode = ((const int32_t *) op->op_params)[2];
1576
1577 const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
1578 const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
1579 const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
1580 const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
1581
1582 if (is_neox) {
1583 snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type));
1584 } else if ((is_mrope || is_imrope) && !is_vision) {
1585 GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
1586 snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type));
1587 } else if (is_vision) {
1588 GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token
1589 snprintf(base, 256, "kernel_rope_vision_%s", ggml_type_name(op->src[0]->type));
1590 } else {
1591 snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type));
1592 }
1593
1594 snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
1595
1596 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1597 if (!res.pipeline) {
1598 ggml_metal_cv_t cv = ggml_metal_cv_init();
1599
1600 ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
1601
1602 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1603
1604 ggml_metal_cv_free(cv);
1605 }
1606
1607 return res;
1608}
1609
1610ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
1611 assert(op->op == GGML_OP_IM2COL);
1612
1613 GGML_ASSERT(ggml_is_contiguous(op->src[1]));
1614 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1615 GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
1616
1617 char base[256];
1618 char name[256];
1619
1620 snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
1621 snprintf(name, 256, "%s", base);
1622
1623 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1624 if (!res.pipeline) {
1625 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1626 }
1627
1628 return res;
1629}
1630
1631ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
1632 assert(op->op == GGML_OP_CONV_TRANSPOSE_1D);
1633
1634 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
1635 GGML_ASSERT(ggml_is_contiguous(op->src[1]));
1636 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
1637 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1638 GGML_ASSERT(op->type == GGML_TYPE_F32);
1639
1640 char base[256];
1641 char name[256];
1642
1643 snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
1644 snprintf(name, 256, "%s", base);
1645
1646 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1647 if (!res.pipeline) {
1648 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1649 }
1650
1651 return res;
1652}
1653
1654ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
1655 assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);
1656
1657 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
1658 GGML_ASSERT(ggml_is_contiguous(op->src[1]));
1659 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
1660 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1661 GGML_ASSERT(op->type == GGML_TYPE_F32);
1662
1663 char base[256];
1664 char name[256];
1665
1666 snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
1667 snprintf(name, 256, "%s", base);
1668
1669 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1670 if (!res.pipeline) {
1671 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1672 }
1673
1674 return res;
1675}
1676
1677ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
1678 assert(op->op == GGML_OP_CONV_2D);
1679
1680 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
1681 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
1682 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
1683 GGML_ASSERT(op->type == GGML_TYPE_F32);
1684
1685 char base[256];
1686 char name[256];
1687
1688 snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
1689 snprintf(name, 256, "%s", base);
1690
1691 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1692 if (!res.pipeline) {
1693 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1694 }
1695
1696 return res;
1697}
1698
1699ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
1700 assert(op->op == GGML_OP_UPSCALE);
1701
1702 char base[256];
1703 char name[256];
1704
1705 snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type));
1706 snprintf(name, 256, "%s", base);
1707
1708 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1709 if (!res.pipeline) {
1710 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1711 }
1712
1713 return res;
1714}
1715
1716ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) {
1717 assert(op->op == GGML_OP_PAD);
1718
1719 char base[256];
1720 char name[256];
1721
1722 snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type));
1723 snprintf(name, 256, "%s", base);
1724
1725 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1726 if (res.pipeline) {
1727 return res;
1728 }
1729
1730 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1731
1732 return res;
1733}
1734
1735ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
1736 assert(op->op == GGML_OP_PAD_REFLECT_1D);
1737
1738 char base[256];
1739 char name[256];
1740
1741 snprintf(base, 256, "kernel_pad_reflect_1d_%s", ggml_type_name(op->src[0]->type));
1742 snprintf(name, 256, "%s", base);
1743
1744 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1745 if (!res.pipeline) {
1746 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1747 }
1748
1749 return res;
1750}
1751
1752ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) {
1753 assert(op->op == GGML_OP_ARANGE);
1754
1755 char base[256];
1756 char name[256];
1757
1758 snprintf(base, 256, "kernel_arange_%s", ggml_type_name(op->type));
1759 snprintf(name, 256, "%s", base);
1760
1761 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1762 if (!res.pipeline) {
1763 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1764 }
1765
1766 return res;
1767}
1768
1769ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) {
1770 assert(op->op == GGML_OP_TIMESTEP_EMBEDDING);
1771
1772 char base[256];
1773 char name[256];
1774
1775 snprintf(base, 256, "kernel_timestep_embedding_%s", ggml_type_name(op->src[0]->type));
1776 snprintf(name, 256, "%s", base);
1777
1778 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1779 if (!res.pipeline) {
1780 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1781 }
1782
1783 return res;
1784}
1785
1786ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
1787 assert(op->op == GGML_OP_OPT_STEP_ADAMW);
1788
1789 char base[256];
1790 char name[256];
1791
1792 snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
1793 snprintf(name, 256, "%s", base);
1794
1795 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1796 if (!res.pipeline) {
1797 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1798 }
1799
1800 return res;
1801}
1802
1803ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
1804 assert(op->op == GGML_OP_OPT_STEP_SGD);
1805
1806 char base[256];
1807 char name[256];
1808
1809 snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
1810 snprintf(name, 256, "%s", base);
1811
1812 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1813 if (!res.pipeline) {
1814 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1815 }
1816
1817 return res;
1818}
1819
1820ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor * op) {
1821 GGML_ASSERT(op->type == GGML_TYPE_I64);
1822
1823 char base[256];
1824 char name[256];
1825
1826 snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type));
1827 snprintf(name, 256, "%s", base);
1828
1829 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1830 if (!res.pipeline) {
1831 res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
1832 }
1833
1834 return res;
1835}
1836
1837ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor * op) {
1838 assert(op->op == GGML_OP_COUNT_EQUAL);
1839
1840 GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
1841
1842 GGML_ASSERT(op->src[0]->type == op->src[1]->type);
1843 GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);
1844 GGML_ASSERT(op->type == GGML_TYPE_I64);
1845
1846    // note: the kernel accumulates in i32 because Metal atomic add supports only atomic_int, hence the element limit below
1847 GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));
1848
1849 char base[256];
1850 char name[256];
1851
1852 int nsg = 1;
1853 while (32*nsg < ne00 && nsg < 32) {
1854 nsg *= 2;
1855 }
1856
1857 snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
1858 snprintf(name, 256, "%s_nsg=%d", base, nsg);
1859
1860 ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
1861 if (!res.pipeline) {
1862 ggml_metal_cv_t cv = ggml_metal_cv_init();
1863
1864 ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);
1865
1866 res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
1867
1868 ggml_metal_cv_free(cv);
1869 }
1870
1871 res.smem = 32 * sizeof(int32_t);
1872 res.nsg = nsg;
1873
1874 return res;
1875}
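
// worked example (hypothetical shape): ne00 = 1000 doubles nsg through 1, 2, 4, 8, 16
// to the cap of 32, since 32*16 = 512 < 1000; smem is 32*sizeof(int32_t) = 128 bytes.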
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.h b/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.h
new file mode 100644
index 0000000..93d7f6a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.h
@@ -0,0 +1,290 @@
1#pragma once
2
3#include "ggml.h"
4
5#ifdef __cplusplus
6extern "C" {
7#endif
8
9struct ggml_metal_buffer_id {
10 void * metal; // id<MTLBuffer>
11 size_t offs;
12};
13
14typedef struct ggml_metal_device * ggml_metal_device_t;
15
16//
17// MTLFunctionConstantValues wrapper
18//
19
20typedef struct ggml_metal_cv * ggml_metal_cv_t;
21
22ggml_metal_cv_t ggml_metal_cv_init(void);
23void ggml_metal_cv_free(ggml_metal_cv_t cv);
24
25void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx);
26void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx);
27void ggml_metal_cv_set_bool (ggml_metal_cv_t cv, bool value, int32_t idx);
28
29//
30// MTLComputePipelineState wrapper
31//
32
33typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t;
34
35ggml_metal_pipeline_t ggml_metal_pipeline_init(void);
36void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline);
37
38// a collection of pipelines
39typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t;
40
41ggml_metal_pipelines_t ggml_metal_pipelines_init(void);
42void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls);
43
44void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline);
45ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name);
46
47struct ggml_metal_pipeline_with_params {
48 ggml_metal_pipeline_t pipeline;
49
50 int nsg;
51
52 int nr0;
53 int nr1;
54
55 size_t smem;
56
57 bool c4;
58 bool cnt;
59};
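
// field summary (based on how the getters in ggml-metal-device.cpp fill the struct):
//   nsg  - number of simdgroups per threadgroup
//   nr0  - number of src0 rows per simdgroup
//   nr1  - number of src1 rows per threadgroup
//   smem - required threadgroup memory size in bytes
//   c4   - whether the 4-wide vectorized kernel variant was selected
//   cnt  - whether the contiguous fast path applies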
60
61int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);
62
63//
64// MTLCommandBuffer wrapper
65//
66
67typedef void * ggml_metal_cmd_buf_t;
68
69//
70// MTLComputeCommandEncoder wrapper
71//
72
73typedef struct ggml_metal_encoder * ggml_metal_encoder_t;
74
75ggml_metal_encoder_t ggml_metal_encoder_init(ggml_metal_cmd_buf_t cmd_buf_raw, bool concurrent);
76void ggml_metal_encoder_free(ggml_metal_encoder_t encoder);
77
78void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name);
79void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder);
80
81void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline);
82
83void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx);
84void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx);
85
86void ggml_metal_encoder_set_threadgroup_memory_size(ggml_metal_encoder_t encoder, size_t size, int idx);
87
88void ggml_metal_encoder_dispatch_threadgroups(ggml_metal_encoder_t encoder, int tg0, int tg1, int tg2, int tptg0, int tptg1, int tptg2);
89
90void ggml_metal_encoder_memory_barrier(ggml_metal_encoder_t encoder);
91
92void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder);
93
94//
95// MTLLibrary wrapper
96//
97
98typedef struct ggml_metal_library * ggml_metal_library_t;
99
100ggml_metal_library_t ggml_metal_library_init (ggml_metal_device_t dev);
101ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose);
102
103void ggml_metal_library_free(ggml_metal_library_t lib);
104
105struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name);
106struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);
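
// typical lookup-or-compile flow on the implementation side (a sketch; `lib`, `enc`, `op`
// and the threadgroup counts ntg0/ntg1/ntg2 are assumed to come from the caller):
//
//   struct ggml_metal_pipeline_with_params p = ggml_metal_library_get_pipeline_sum_rows(lib, op);
//
//   ggml_metal_encoder_set_pipeline(enc, p);
//   if (p.smem > 0) {
//       ggml_metal_encoder_set_threadgroup_memory_size(enc, p.smem, 0);
//   }
//   ggml_metal_encoder_dispatch_threadgroups(enc, ntg0, ntg1, ntg2, 32, 1, 1);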
107
108struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op);
109struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
110struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
111struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
112struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
113struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
114struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag (ggml_metal_library_t lib, const struct ggml_tensor * op);
115struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
116struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
117struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
118struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
119struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
120struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op);
121struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op);
122struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
123struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
124struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
125struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
126struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
127struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
128struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
129struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
130struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
131struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op);
132struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20);
133struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
134struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
135struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
136struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
137struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
138struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
139struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
140struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse );
141struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one (ggml_metal_library_t lib, enum ggml_op op);
142struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
143struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
144struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
145struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
146struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op);
147struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
148struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
149struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
150struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op);
151struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op);
152struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
153struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
154struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
155struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
156struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
157struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset (ggml_metal_library_t lib, const struct ggml_tensor * op);
158struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal (ggml_metal_library_t lib, const struct ggml_tensor * op);
159
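// flash attention pipelines are compiled in several specializations: the boolean flags select
// kernel variants built without the corresponding branches (mask, attention sinks, bias,
// logit soft-capping, KV padding). the integer parameters follow the kernel naming convention
// (a non-normative reading: nqptg = queries per threadgroup, ncpsg = cache values per simdgroup,
// nsg = simdgroups per threadgroup, nwg = workgroups, dv = V head size)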
160struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
161 ggml_metal_library_t lib,
162 const struct ggml_tensor * op,
163 bool has_mask,
164 int32_t ncpsg);
165
166struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(
167 ggml_metal_library_t lib,
168 const struct ggml_tensor * op,
169 int32_t nqptg,
170 int32_t ncpsg);
171
172struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
173 ggml_metal_library_t lib,
174 const struct ggml_tensor * op,
175 bool has_mask,
176 bool has_sinks,
177 bool has_bias,
178 bool has_scap,
179 bool has_kvpad,
180 int32_t nsg);
181
182struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(
183 ggml_metal_library_t lib,
184 const struct ggml_tensor * op,
185 bool has_mask,
186 bool has_sinks,
187 bool has_bias,
188 bool has_scap,
189 bool has_kvpad,
190 int32_t nsg,
191 int32_t nwg);
192
193struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
194 ggml_metal_library_t lib,
195 const struct ggml_tensor * op,
196 int32_t dv,
197 int32_t nwg);
198
199// MTLResidencySet wrapper
200
201typedef void * ggml_metal_rset_t;
202
203// a collection of residency sets (non-owning)
204typedef struct ggml_metal_rsets * ggml_metal_rsets_t;
205
206ggml_metal_rsets_t ggml_metal_rsets_init(void);
207void ggml_metal_rsets_free(ggml_metal_rsets_t rsets);
208
209//
210// device
211//
212
213struct ggml_metal_device_props {
214 int device;
215 char name[128];
216 char desc[128];
217
218 size_t max_buffer_size;
219 size_t max_working_set_size;
220 size_t max_theadgroup_memory_size;
221
222 bool has_simdgroup_reduction;
223 bool has_simdgroup_mm;
224 bool has_unified_memory;
225 bool has_bfloat;
226 bool has_tensor;
227 bool use_residency_sets;
228 bool use_shared_buffers;
229
230 bool supports_gpu_family_apple7;
231
232 int op_offload_min_batch_size;
233};
234
235typedef struct ggml_metal_event * ggml_metal_event_t;
236
237void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf);
238void ggml_metal_event_encode_wait (ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf);
239
240ggml_metal_device_t ggml_metal_device_init(int device);
241void ggml_metal_device_free(ggml_metal_device_t dev);
242
243ggml_metal_device_t ggml_metal_device_get(int device);
244
245void * ggml_metal_device_get_obj (ggml_metal_device_t dev); // id<MTLDevice>
246void * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id<MTLCommandQueue>
247
248ggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev);
249
250void ggml_metal_device_rsets_add(ggml_metal_device_t dev, ggml_metal_rset_t rset);
251void ggml_metal_device_rsets_rm (ggml_metal_device_t dev, ggml_metal_rset_t rset);
252
253void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev);
254
255ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev);
256void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev);
257void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev);
258
259void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total);
260bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op);
261
262const struct ggml_metal_device_props * ggml_metal_device_get_props(ggml_metal_device_t dev);
263
264//
265// device buffers
266//
267
268typedef struct ggml_metal_buffer * ggml_metal_buffer_t;
269
270ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared);
271ggml_metal_buffer_t ggml_metal_buffer_map (ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size);
272
273void ggml_metal_buffer_free (ggml_metal_buffer_t buf);
274void * ggml_metal_buffer_get_base (ggml_metal_buffer_t buf);
275bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf);
276
277void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
278void ggml_metal_buffer_set_tensor (ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
279void ggml_metal_buffer_get_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
280void ggml_metal_buffer_clear (ggml_metal_buffer_t buf, uint8_t value);
281
282// finds the Metal buffer that contains the tensor data on the GPU device
283// the assumption is that there is a 1-to-1 mapping between the host and device memory buffers, so we can find the
284// Metal buffer based on the host memory pointer
285//
286struct ggml_metal_buffer_id ggml_metal_buffer_get_id(ggml_metal_buffer_t buf, const struct ggml_tensor * t);
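//
// usage sketch (illustrative, assuming a miss is signaled by a nil Metal buffer):
//
//   struct ggml_metal_buffer_id bid = ggml_metal_buffer_get_id(buf, tensor);
//   if (bid.metal == nil) {
//       // the tensor data does not reside in this buffer
//   }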
287
288#ifdef __cplusplus
289}
290#endif
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m b/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m
new file mode 100644
index 0000000..4ea0bfb
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m
@@ -0,0 +1,1748 @@
1#import "ggml-metal-device.h"
2
3#import "ggml-impl.h"
4
5#include <Foundation/Foundation.h>
6
7#include <Metal/Metal.h>
8
9#include <stdatomic.h>
10
11#ifndef TARGET_OS_VISION
12#define TARGET_OS_VISION 0
13#endif
14
15// create residency sets only on recent OSes: macOS >= 15.0, iOS/tvOS >= 18.0, visionOS >= 2.0
16#if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \
17 TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \
18 TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \
19 TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000
20#define GGML_METAL_HAS_RESIDENCY_SETS 1
21#endif
22
23// overload of MTLGPUFamilyMetalX (not available in some environments)
24static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
25static const NSInteger MTLGPUFamilyMetal4_GGML = 5002;
26
27#if !GGML_METAL_EMBED_LIBRARY
28// Here to assist with NSBundle Path Hack
29@interface GGMLMetalClass : NSObject
30@end
31@implementation GGMLMetalClass
32@end
33#endif
34
35//
36// MTLFunctionConstantValues wrapper
37//
38
39struct ggml_metal_cv {
40 MTLFunctionConstantValues * obj;
41};
42
43ggml_metal_cv_t ggml_metal_cv_init(void) {
44 ggml_metal_cv_t res = calloc(1, sizeof(struct ggml_metal_cv));
45
46 res->obj = [[MTLFunctionConstantValues alloc] init];
47
48 return res;
49}
50
51void ggml_metal_cv_free(ggml_metal_cv_t cv) {
52 [cv->obj release];
53 free(cv);
54}
55
56void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx) {
57 [cv->obj setConstantValue:&value type:MTLDataTypeShort atIndex:idx];
58}
59
60void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx) {
61 [cv->obj setConstantValue:&value type:MTLDataTypeInt atIndex:idx];
62}
63
64void ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) {
65 [cv->obj setConstantValue:&value type:MTLDataTypeBool atIndex:idx];
66}
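
// usage sketch (illustrative): constant values are set by index and consumed when compiling a
// pipeline variant; the kernel names and the constant index below are hypothetical
//
//   ggml_metal_cv_t cv = ggml_metal_cv_init();
//   ggml_metal_cv_set_bool(cv, has_mask, 0);
//   struct ggml_metal_pipeline_with_params ppl =
//       ggml_metal_library_compile_pipeline(lib, "kernel_flash_attn_ext", "kernel_flash_attn_ext_mask", cv);
//   ggml_metal_cv_free(cv);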
67
68//
69// MTLComputePipelineState wrapper
70//
71
72struct ggml_metal_pipeline {
73 id<MTLComputePipelineState> obj;
74};
75
76ggml_metal_pipeline_t ggml_metal_pipeline_init(void) {
77 ggml_metal_pipeline_t res = calloc(1, sizeof(struct ggml_metal_pipeline));
78
79 *res = (struct ggml_metal_pipeline) {
80 /*.obj =*/ nil,
81 };
82
83 return res;
84}
85
86void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) {
87 [pipeline->obj release];
88
89 free(pipeline);
90}
91
92int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline) {
93 return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup;
94}
95
96struct ggml_metal_library {
97 id<MTLLibrary> obj;
98 id<MTLDevice> device;
99
100 ggml_metal_pipelines_t pipelines; // cache of compiled pipelines
101
102 NSLock * lock;
103};
104
105ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) {
106 id<MTLLibrary> library = nil;
107 id<MTLDevice> device = ggml_metal_device_get_obj(dev);
108
109 // load library
110 //
111 // - first check if the library is embedded
112 // - then check if the library is in the bundle
113 // - if not found, load the source and compile it
114 // - if that fails, return NULL
115 //
116 // TODO: move to a function
117 {
118 const int64_t t_start = ggml_time_us();
119
120 NSError * error = nil;
121 NSString * src = nil;
122
123#if GGML_METAL_EMBED_LIBRARY
124 GGML_LOG_INFO("%s: using embedded metal library\n", __func__);
125
126 extern const char ggml_metallib_start[];
127 extern const char ggml_metallib_end[];
128
129 src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding];
130#else
131
132#ifdef SWIFT_PACKAGE
133 NSBundle * bundle = SWIFTPM_MODULE_BUNDLE;
134#else
135 NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
136#endif
137
138 NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
139 if (path_lib == nil) {
140 // Try to find the resource in the directory where the current binary is located.
141 NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0];
142 NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent];
143
144 NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
145 if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
146 GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]);
147
148 NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error];
149 if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
150 // Optionally, if this is a symlink, try to resolve it.
151 path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error];
152 if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) {
153 // It is a relative path, adding the binary directory as directory prefix.
154 path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]];
155 }
156 if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) {
157 // Link to the resource could not be resolved.
158 path_lib_default = nil;
159 } else {
160 GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]);
161 }
162 }
163 } else {
164 // The resource couldn't be found in the binary's directory.
165 path_lib_default = nil;
166 }
167
168 path_lib = path_lib_default;
169 }
170
171 if (path_lib != nil) {
172 // pre-compiled library found
173 NSURL * libURL = [NSURL fileURLWithPath:path_lib];
174 GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
175
176 library = [device newLibraryWithURL:libURL error:&error];
177 if (error) {
178 GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
179 return nil;
180 }
181 } else {
182 GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
183
184 NSString * path_source;
185 NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"];
186
187 GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil");
188
189 if (path_resource) {
190 path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"];
191 } else {
192 path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
193 }
194
195 if (path_source == nil) {
196 GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__);
197 path_source = @"ggml-metal.metal";
198 }
199
200 GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]);
201
202 src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error];
203 if (error) {
204 GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
205 return nil;
206 }
207 }
208#endif
209
210 if (!library) {
211 @autoreleasepool {
212 // dictionary of preprocessor macros
213 NSMutableDictionary * prep = [NSMutableDictionary dictionary];
214
215 if (ggml_metal_device_get_props(dev)->has_bfloat) {
216 [prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"];
217 }
218
219 if (ggml_metal_device_get_props(dev)->has_tensor) {
220 [prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"];
221 }
222
223#if GGML_METAL_EMBED_LIBRARY
224 [prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"];
225#endif
226
227 MTLCompileOptions * options = [MTLCompileOptions new];
228 options.preprocessorMacros = prep;
229
230 //[options setFastMathEnabled:false];
231
232 library = [device newLibraryWithSource:src options:options error:&error];
233 if (error) {
234 GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
235 return nil;
236 }
237
238#if !__has_feature(objc_arc)
239 [options release];
240#endif
241 }
242 }
243
244#if GGML_METAL_EMBED_LIBRARY
245 [src release];
246#endif // GGML_METAL_EMBED_LIBRARY
247
248 GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
249 }
250
251 ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
252
253 res->obj = library;
254 res->device = device;
255 res->pipelines = ggml_metal_pipelines_init();
256 res->lock = [NSLock new];
257
258 return res;
259}
260
261ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) {
262 if (source == NULL) {
263 GGML_LOG_ERROR("%s: source is NULL\n", __func__);
264 return NULL;
265 }
266
267 id<MTLDevice> device = ggml_metal_device_get_obj(dev);
268 id<MTLLibrary> library = nil;
269 NSError * error = nil;
270
271 const int64_t t_start = ggml_time_us();
272
273 NSString * src = [[NSString alloc] initWithBytes:source
274 length:strlen(source)
275 encoding:NSUTF8StringEncoding];
276 if (!src) {
277 GGML_LOG_ERROR("%s: failed to create NSString from source\n", __func__);
278 return NULL;
279 }
280
281 @autoreleasepool {
282 NSMutableDictionary * prep = [NSMutableDictionary dictionary];
283
284 MTLCompileOptions * options = [MTLCompileOptions new];
285 options.preprocessorMacros = prep;
286
287 library = [device newLibraryWithSource:src options:options error:&error];
288 if (error) {
289 if (verbose) {
290 GGML_LOG_ERROR("%s: error compiling source: %s\n", __func__, [[error description] UTF8String]);
291 } else {
292 GGML_LOG_ERROR("%s: error compiling source\n", __func__);
293 }
294 library = nil;
295 }
296
297 [options release];
298 }
299
300 [src release];
301
302 if (!library) {
303 if (verbose) {
304 GGML_LOG_ERROR("%s: failed to create Metal library from source\n", __func__);
305 }
306
307 return NULL;
308 }
309
310 if (verbose) {
311 GGML_LOG_INFO("%s: compiled in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6);
312 }
313
314 ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library));
315 if (!res) {
316 GGML_LOG_ERROR("%s: calloc failed\n", __func__);
317 return NULL;
318 }
319
320 res->obj = library;
321 res->device = device;
322 res->pipelines = ggml_metal_pipelines_init();
323 res->lock = [NSLock new];
324
325 return res;
326}
327
328void ggml_metal_library_free(ggml_metal_library_t lib) {
329 if (!lib) {
330 return;
331 }
332
333 if (lib->obj) {
334 [lib->obj release];
335 }
336
337 ggml_metal_pipelines_free(lib->pipelines);
338
339 [lib->lock release];
340
341 free(lib);
342}
343
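// look up a previously compiled pipeline in the cache (lock-protected)
// returns a nil pipeline on a cache miss - use ggml_metal_library_compile_pipeline() to
// compile and cache new variants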
344struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
345 [lib->lock lock];
346
347 struct ggml_metal_pipeline_with_params res = {
348 /*.pipeline =*/ nil,
349 /*.nsg =*/ 0,
350 /*.nr0 =*/ 0,
351 /*.nr1 =*/ 0,
352 /*.smem =*/ 0,
353 /*.c4 =*/ false,
354 /*.cnt =*/ false,
355 };
356
357 res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
358
359 [lib->lock unlock];
360
361 return res;
362}
363
364struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
365 struct ggml_metal_pipeline_with_params res = {
366 /*.pipeline =*/ nil,
367 /*.nsg =*/ 0,
368 /*.nr0 =*/ 0,
369 /*.nr1 =*/ 0,
370 /*.smem =*/ 0,
371 /*.c4 =*/ false,
372 /*.cnt =*/ false,
373 };
374
375 [lib->lock lock];
376
377 res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
378 if (res.pipeline) {
379 [lib->lock unlock];
380
381 return res;
382 }
383
384 @autoreleasepool {
385 NSError * error = nil;
386
387 NSString * base_func = [NSString stringWithUTF8String:base];
388
389 GGML_LOG_DEBUG("%s: compiling pipeline: base = '%s', name = '%s'\n", __func__, base, name);
390
391 id<MTLFunction> mtl_function;
392 if (!cv) {
393 mtl_function = [lib->obj newFunctionWithName:base_func];
394 } else {
395 mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error];
396 }
397 if (!mtl_function) {
398 [lib->lock unlock];
399
400 GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name);
401 if (error) {
402 GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
403 }
404
405 return res;
406 }
407
408 id<MTLComputePipelineState> obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
409
410 [mtl_function release];
411
412 if (!obj) {
413 [lib->lock unlock];
414
415 GGML_LOG_ERROR("%s: failed to create pipeline state: base = '%s', name = '%s'\n", __func__, base, name);
416 if (error) {
417 GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
418 }
419
420 return res;
421 }
422
423 GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name,
424 (void *) obj,
425 (int) obj.maxTotalThreadsPerThreadgroup,
426 (int) obj.threadExecutionWidth);
427
428 if (obj.maxTotalThreadsPerThreadgroup == 0 || obj.threadExecutionWidth == 0) {
429 [obj release];
430
431 [lib->lock unlock];
432
433 GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
434
435 return res;
436 }
437
438 res.pipeline = ggml_metal_pipeline_init();
439 res.pipeline->obj = obj;
440
441 ggml_metal_pipelines_add(lib->pipelines, name, res.pipeline);
442 }
443
444 [lib->lock unlock];
445
446 return res;
447}
448
449//
450// MTLComputeCommandEncoder wrapper
451//
452
453struct ggml_metal_encoder {
454 id<MTLComputeCommandEncoder> obj;
455};
456
457ggml_metal_encoder_t ggml_metal_encoder_init(ggml_metal_cmd_buf_t cmd_buf_raw, bool concurrent) {
458 ggml_metal_encoder_t res = calloc(1, sizeof(struct ggml_metal_encoder));
459
460 id<MTLCommandBuffer> cmd_buf = (id<MTLCommandBuffer>) cmd_buf_raw;
461
462 if (concurrent) {
463 res->obj = [cmd_buf computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
464 } else {
465 res->obj = [cmd_buf computeCommandEncoder];
466 }
467
468 [res->obj retain];
469
470 return res;
471}
472
473void ggml_metal_encoder_free(ggml_metal_encoder_t encoder) {
474 [encoder->obj release];
475 free(encoder);
476}
477
478void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name) {
479 [encoder->obj pushDebugGroup:[NSString stringWithCString:name encoding:NSUTF8StringEncoding]];
480}
481
482void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder) {
483 [encoder->obj popDebugGroup];
484}
485
486void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline) {
487 [encoder->obj setComputePipelineState:pipeline.pipeline->obj];
488}
489
490void ggml_metal_encoder_set_bytes(ggml_metal_encoder_t encoder, void * data, size_t size, int idx) {
491 [encoder->obj setBytes:data length:size atIndex:idx];
492}
493
494void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx) {
495 [encoder->obj setBuffer:buffer.metal offset:buffer.offs atIndex:idx];
496}
497
498void ggml_metal_encoder_set_threadgroup_memory_size(ggml_metal_encoder_t encoder, size_t size, int idx) {
499 [encoder->obj setThreadgroupMemoryLength:size atIndex:idx];
500}
501
502void ggml_metal_encoder_dispatch_threadgroups(ggml_metal_encoder_t encoder, int tg0, int tg1, int tg2, int tptg0, int tptg1, int tptg2) {
503 [encoder->obj dispatchThreadgroups:MTLSizeMake(tg0, tg1, tg2) threadsPerThreadgroup:MTLSizeMake(tptg0, tptg1, tptg2)];
504}
505
506void ggml_metal_encoder_memory_barrier(ggml_metal_encoder_t encoder) {
507 [encoder->obj memoryBarrierWithScope:MTLBarrierScopeBuffers];
508}
509
510void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder) {
511 [encoder->obj endEncoding];
512}
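
// typical encoding flow (sketch): create an encoder for a command buffer, bind a pipeline and
// its buffers, dispatch, and end encoding
//
//   ggml_metal_encoder_t enc = ggml_metal_encoder_init(cmd_buf, /*concurrent =*/ true);
//   ggml_metal_encoder_set_pipeline(enc, ppl);
//   ggml_metal_encoder_set_buffer (enc, bid_src0, 0);
//   ggml_metal_encoder_set_buffer (enc, bid_dst, 1);
//   ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);
//   ggml_metal_encoder_end_encoding(enc);
//   ggml_metal_encoder_free(enc);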
513
514struct ggml_metal_device {
515 id<MTLDevice> mtl_device;
516
517 // a single global queue shared by all Metal backends
518 // technically not needed for devices with unified memory, but enables support for discrete GPUs
519 // ref: https://github.com/ggml-org/llama.cpp/pull/15906
520 id<MTLCommandQueue> mtl_queue;
521
522 ggml_metal_rsets_t rsets;
523
524 ggml_metal_library_t library;
525
526 struct ggml_metal_device_props props;
527
528 // virtual address for GPU memory allocations
529 atomic_uintptr_t addr_virt;
530};
531
532//
533// MTLResidencySet wrapper
534//
535
536struct ggml_metal_rsets {
537 NSLock * lock;
538
539 NSMutableArray * data;
540
541 // number of seconds to keep the residency sets wired after the last graph computation
542 // this avoids having the memory collected by the OS between computations
543 int keep_alive_s;
544
545 // background heartbeat thread to keep the residency sets alive
546 atomic_bool d_stop;
547 atomic_int d_loop;
548
549 dispatch_group_t d_group;
550};
551
552ggml_metal_rsets_t ggml_metal_rsets_init(void) {
553 ggml_metal_rsets_t res = calloc(1, sizeof(struct ggml_metal_rsets));
554
555 res->lock = [[NSLock alloc] init];
556 res->data = [[NSMutableArray alloc] init];
557
558 // by default keep the memory wired for 3 minutes
559 res->keep_alive_s = 3*60;
560
561 const char * GGML_METAL_RESIDENCY_KEEP_ALIVE_S = getenv("GGML_METAL_RESIDENCY_KEEP_ALIVE_S");
562 if (GGML_METAL_RESIDENCY_KEEP_ALIVE_S) {
563 res->keep_alive_s = atoi(GGML_METAL_RESIDENCY_KEEP_ALIVE_S);
564 }
565
566 if (res->keep_alive_s <= 0) {
567 res->keep_alive_s = 3*60;
568 }
569
570 GGML_LOG_INFO("%s: creating a residency set collection (keep_alive = %d s)\n", __func__, res->keep_alive_s);
571
572 atomic_store_explicit(&res->d_stop, false, memory_order_relaxed);
573 atomic_store_explicit(&res->d_loop, 2*res->keep_alive_s, memory_order_relaxed);
574
575 res->d_group = dispatch_group_create();
576
577 // start a background thread that periodically requests residency for all the currently active sets in the collection
578 // the requests stop after a certain amount of time (keep_alive_s) of inactivity
579 dispatch_queue_t d_queue = dispatch_get_global_queue(QOS_CLASS_DEFAULT, 0);
580 dispatch_group_async(res->d_group, d_queue, ^{
581#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
582 if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
583 while (!atomic_load_explicit(&res->d_stop, memory_order_relaxed)) {
584 if (atomic_load_explicit(&res->d_loop, memory_order_relaxed) > 0) {
585 [res->lock lock];
586
587 for (int i = 0; i < (int) res->data.count; ++i) {
588 [res->data[i] requestResidency];
589 }
590
591 atomic_fetch_sub_explicit(&res->d_loop, 1, memory_order_relaxed);
592
593 [res->lock unlock];
594 }
595
596 // half a second
597 usleep(500 * 1000);
598 }
599 }
600#endif
601 });
602
603 return res;
604}
605
606void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) {
607 if (rsets == NULL) {
608 return;
609 }
610
611 // note: if you hit this assert, most likely you haven't deallocated all Metal resources before exiting
612 GGML_ASSERT([rsets->data count] == 0);
613
614 atomic_store_explicit(&rsets->d_stop, true, memory_order_relaxed);
615
616 dispatch_group_wait(rsets->d_group, DISPATCH_TIME_FOREVER);
617 dispatch_release(rsets->d_group);
618
619 [rsets->data release];
620 [rsets->lock release];
621
622 free(rsets);
623}
624
625ggml_metal_device_t ggml_metal_device_init(int device) {
626 ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device));
627
628 assert(dev != NULL);
629
630 if (dev->mtl_device == nil) {
631 dev->mtl_device = MTLCreateSystemDefaultDevice();
632
633 if (dev->mtl_device) {
634 dev->mtl_queue = [dev->mtl_device newCommandQueue];
635 if (dev->mtl_queue == nil) {
636 GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
637 }
638
639 dev->addr_virt = 0x000000400ULL;
640
641 dev->props.device = device;
642 dev->props.has_simdgroup_reduction = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7];
643 dev->props.has_simdgroup_reduction |= [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
644
645 dev->props.has_simdgroup_mm = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7];
646 dev->props.has_unified_memory = dev->mtl_device.hasUnifiedMemory;
647
648 dev->props.has_bfloat = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
649 dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6];
650 if (getenv("GGML_METAL_BF16_DISABLE") != NULL) {
651 dev->props.has_bfloat = false;
652 }
653
654 dev->props.has_tensor = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal4_GGML];
655 if (getenv("GGML_METAL_TENSOR_DISABLE") != NULL) {
656 dev->props.has_tensor = false;
657 }
658
659 // note: disable the tensor API by default on older chips because the current implementation is not faster there:
660 // - M2 Ultra: ~5% slower
661 // - M4, M4 Max: no significant difference
662 //
663 // TODO: try to update the tensor API kernels to at least match the simdgroup performance
664 if (getenv("GGML_METAL_TENSOR_ENABLE") == NULL &&
665 ![[dev->mtl_device name] containsString:@"M5"] &&
666 ![[dev->mtl_device name] containsString:@"M6"] &&
667 ![[dev->mtl_device name] containsString:@"A19"] &&
668 ![[dev->mtl_device name] containsString:@"A20"]) {
669 GGML_LOG_WARN("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__);
670 dev->props.has_tensor = false;
671 }
672
673 // double-check that the tensor API compiles
674 if (dev->props.has_tensor) {
675 const char * src_tensor_f16 = "\n"
676 "#include <metal_stdlib> \n"
677 "#include <metal_tensor> \n"
678 "#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
679 " \n"
680 "using namespace metal; \n"
681 "using namespace mpp::tensor_ops; \n"
682 " \n"
683 "kernel void dummy_kernel( \n"
684 " tensor<device half, dextents<int32_t, 2>> A [[buffer(0)]], \n"
685 " tensor<device half, dextents<int32_t, 2>> B [[buffer(1)]], \n"
686 " device float * C [[buffer(2)]], \n"
687 " uint2 tgid [[threadgroup_position_in_grid]]) \n"
688 "{ \n"
689 " auto tA = A.slice(0, (int)tgid.y); \n"
690 " auto tB = B.slice((int)tgid.x, 0); \n"
691 " \n"
692 " matmul2d< \n"
693 " matmul2d_descriptor(8, 8, dynamic_extent), \n"
694 " execution_simdgroups<4>> mm; \n"
695 " \n"
696 " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
697 " \n"
698 " auto sA = tA.slice(0, 0); \n"
699 " auto sB = tB.slice(0, 0); \n"
700 " mm.run(sB, sA, cT); \n"
701 " \n"
702 " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
703 " \n"
704 " cT.store(tC); \n"
705 "}";
706
707 GGML_LOG_INFO("%s: testing tensor API for f16 support\n", __func__);
708 ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_f16, false);
709 if (lib == NULL) {
710 GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
711 dev->props.has_tensor = false;
712 } else {
713 struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
714 if (!ppl.pipeline) {
715 GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
716 dev->props.has_tensor = false;
717 }
718
719 ggml_metal_library_free(lib);
720 }
721 }
722
723 // try to compile a dummy kernel to determine if the tensor API is supported for bfloat
724 if (dev->props.has_tensor && dev->props.has_bfloat) {
725 const char * src_tensor_bf16 = "\n"
726 "#include <metal_stdlib> \n"
727 "#include <metal_tensor> \n"
728 "#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n"
729 " \n"
730 "using namespace metal; \n"
731 "using namespace mpp::tensor_ops; \n"
732 " \n"
733 "kernel void dummy_kernel( \n"
734 " tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n"
735 " tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n"
736 " device float * C [[buffer(2)]], \n"
737 " uint2 tgid [[threadgroup_position_in_grid]]) \n"
738 "{ \n"
739 " auto tA = A.slice(0, (int)tgid.y); \n"
740 " auto tB = B.slice((int)tgid.x, 0); \n"
741 " \n"
742 " matmul2d< \n"
743 " matmul2d_descriptor(8, 8, dynamic_extent), \n"
744 " execution_simdgroups<4>> mm; \n"
745 " \n"
746 " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n"
747 " \n"
748 " auto sA = tA.slice(0, 0); \n"
749 " auto sB = tB.slice(0, 0); \n"
750 " mm.run(sB, sA, cT); \n"
751 " \n"
752 " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n"
753 " \n"
754 " cT.store(tC); \n"
755 "}";
756
757 GGML_LOG_INFO("%s: testing tensor API for bfloat support\n", __func__);
758 ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_bf16, false);
759 if (lib == NULL) {
760 GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
761 dev->props.has_bfloat = false;
762 } else {
763 struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
764 if (!ppl.pipeline) {
765 GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
766 dev->props.has_bfloat = false;
767 }
768
769 ggml_metal_library_free(lib);
770 }
771 }
772
773 dev->props.use_residency_sets = true;
774#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
775 dev->props.use_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil;
776#endif
777
778 dev->props.use_shared_buffers = dev->props.has_unified_memory;
779#if TARGET_OS_OSX
780 // In case of eGPU, shared memory may be preferable.
781 dev->props.use_shared_buffers |= [dev->mtl_device location] == MTLDeviceLocationExternal;
782#endif
783 if (getenv("GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) {
784 dev->props.use_shared_buffers = false;
785 }
786 if (getenv("GGML_METAL_SHARED_BUFFERS_ENABLE") != NULL) {
787 dev->props.use_shared_buffers = true;
788 }
789
790 dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7];
791
792 dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
793
794 dev->props.max_buffer_size = dev->mtl_device.maxBufferLength;
795 dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength;
796 if (@available(macOS 10.12, iOS 16.0, *)) {
797 dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize;
798 } else {
799 dev->props.max_working_set_size = dev->mtl_device.maxBufferLength;
800 }
801
802 snprintf(dev->props.name, sizeof(dev->props.name), "%s%d", "MTL", device);
803 snprintf(dev->props.desc, sizeof(dev->props.desc), "%s", [[dev->mtl_device name] UTF8String]);
804
805 dev->library = ggml_metal_library_init(dev);
806 if (!dev->library) {
807 GGML_LOG_ERROR("%s: error: failed to create library\n", __func__);
808 }
809
810 if (dev->props.use_residency_sets) {
811 dev->rsets = ggml_metal_rsets_init();
812 } else {
813 dev->rsets = nil;
814 }
815
816 // print MTL GPU family:
817 GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name);
818
819 // determine max supported GPU family
820 // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
821 // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
822 {
823 for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
824 if ([dev->mtl_device supportsFamily:i]) {
825 GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
826 break;
827 }
828 }
829
830 for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
831 if ([dev->mtl_device supportsFamily:i]) {
832 GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
833 break;
834 }
835 }
836
837 for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) {
838 if ([dev->mtl_device supportsFamily:i]) {
839 GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i);
840 break;
841 }
842 }
843 }
844
845 GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, dev->props.has_simdgroup_reduction ? "true" : "false");
846 GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, dev->props.has_simdgroup_mm ? "true" : "false");
847 GGML_LOG_INFO("%s: has unified memory = %s\n", __func__, dev->props.has_unified_memory ? "true" : "false");
848 GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, dev->props.has_bfloat ? "true" : "false");
849 GGML_LOG_INFO("%s: has tensor = %s\n", __func__, dev->props.has_tensor ? "true" : "false");
850 GGML_LOG_INFO("%s: use residency sets = %s\n", __func__, dev->props.use_residency_sets ? "true" : "false");
851 GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, dev->props.use_shared_buffers ? "true" : "false");
852
853#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
854 if (@available(macOS 10.12, iOS 16.0, *)) {
855 GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, dev->props.max_working_set_size / 1e6);
856 }
857#endif
858 }
859 }
860
861 return dev;
862}
863
864void ggml_metal_device_free(ggml_metal_device_t dev) {
865 assert(dev != NULL);
866
867 ggml_metal_rsets_free(dev->rsets);
868
869 ggml_metal_library_free(dev->library);
870 dev->library = NULL;
871
872 if (dev->mtl_queue) {
873 [dev->mtl_queue release];
874 dev->mtl_queue = nil;
875 }
876
877 if (dev->mtl_device) {
878 [dev->mtl_device release];
879 dev->mtl_device = nil;
880 }
881
882 free(dev);
883}
884
885void * ggml_metal_device_get_obj(ggml_metal_device_t dev) {
886 return dev->mtl_device;
887}
888
889void * ggml_metal_device_get_queue(ggml_metal_device_t dev) {
890 return dev->mtl_queue;
891}
892
893ggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev) {
894 return dev->library;
895}
896
897void ggml_metal_device_rsets_add(ggml_metal_device_t dev, ggml_metal_rset_t rset) {
898 if (rset == nil) {
899 return;
900 }
901
902 GGML_ASSERT(dev->rsets);
903
904 [dev->rsets->lock lock];
905
906 [dev->rsets->data addObject:rset];
907
908 [dev->rsets->lock unlock];
909}
910
911void ggml_metal_device_rsets_rm(ggml_metal_device_t dev, ggml_metal_rset_t rset) {
912 if (rset == nil) {
913 return;
914 }
915
916 GGML_ASSERT(dev->rsets);
917
918 [dev->rsets->lock lock];
919
920 [dev->rsets->data removeObject:rset];
921
922 [dev->rsets->lock unlock];
923}
924
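// reset the heartbeat countdown: the background thread ticks twice per second, so a counter of
// 2*keep_alive_s keeps the residency sets wired for keep_alive_s seconds after the last call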
925void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) {
926 if (dev->rsets == NULL) {
927 return;
928 }
929
930 atomic_store_explicit(&dev->rsets->d_loop, 2*dev->rsets->keep_alive_s, memory_order_relaxed);
931}
932
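// an event pairs an MTLEvent with a monotonically increasing counter: each encoded signal bumps
// the counter and each wait is encoded against the most recently signaled value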
933struct ggml_metal_event {
934 void * obj; // id<MTLEvent>
935
936 atomic_int value;
937};
938
939void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) {
940 id<MTLEvent> event = (id<MTLEvent>)ev->obj;
941
942 id<MTLCommandBuffer> cmd_buf = (id<MTLCommandBuffer>) cmd_buf_raw;
943
944 [cmd_buf encodeSignalEvent:event value:atomic_fetch_add_explicit(&ev->value, 1, memory_order_relaxed) + 1];
945}
946
947void ggml_metal_event_encode_wait(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) {
948 id<MTLEvent> event = (id<MTLEvent>)ev->obj;
949
950 id<MTLCommandBuffer> cmd_buf = (id<MTLCommandBuffer>) cmd_buf_raw;
951
952 [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)];
953}
954
955ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev) {
956 id<MTLEvent> event = [dev->mtl_device newEvent];
957
958 ggml_metal_event_t ev = calloc(1, sizeof(struct ggml_metal_event));
959
960 ev->obj = (__bridge void *)event;
961 ev->value = 0;
962
963 return ev;
964}
965
966void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev) {
967 id<MTLEvent> event = ev->obj;
968 [event release];
969
970 free(ev);
971
972 GGML_UNUSED(dev);
973}
974
975void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev) {
976 @autoreleasepool {
977 id<MTLEvent> event = ev->obj;
978
979 id<MTLCommandBuffer> cmd_buf = [dev->mtl_queue commandBuffer];
980 [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)];
981 [cmd_buf commit];
982 [cmd_buf waitUntilCompleted];
983 }
984}
985
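// note: "free" is an estimate relative to the recommended working set size - it is not a query
// of the actual available system memory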
986void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) {
987 if (@available(macOS 10.12, iOS 16.0, *)) {
988 *total = dev->mtl_device.recommendedMaxWorkingSetSize;
989 *free = *total - dev->mtl_device.currentAllocatedSize;
990 } else {
991 *free = 0;
992 *total = 0;
993 }
994}
995
996bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op) {
997 const bool has_simdgroup_mm = dev->props.has_simdgroup_mm;
998 const bool has_simdgroup_reduction = dev->props.has_simdgroup_reduction;
999 const bool has_bfloat = dev->props.has_bfloat;
1000
1001 if (!has_bfloat) {
1002 if (op->type == GGML_TYPE_BF16) {
1003 return false;
1004 }
1005
1006 for (size_t i = 0, n = 3; i < n; ++i) {
1007 if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
1008 return false;
1009 }
1010 }
1011 }
1012
1013 switch (op->op) {
1014 case GGML_OP_SCALE:
1015 case GGML_OP_FILL:
1016 case GGML_OP_CLAMP:
1017 case GGML_OP_SQR:
1018 case GGML_OP_SQRT:
1019 case GGML_OP_SIN:
1020 case GGML_OP_COS:
1021 case GGML_OP_LOG:
1022 return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
1023 case GGML_OP_UNARY:
1024 switch (ggml_get_unary_op(op)) {
1025 case GGML_UNARY_OP_TANH:
1026 case GGML_UNARY_OP_RELU:
1027 case GGML_UNARY_OP_SIGMOID:
1028 case GGML_UNARY_OP_GELU:
1029 case GGML_UNARY_OP_GELU_ERF:
1030 case GGML_UNARY_OP_GELU_QUICK:
1031 case GGML_UNARY_OP_SILU:
1032 case GGML_UNARY_OP_ELU:
1033 case GGML_UNARY_OP_NEG:
1034 case GGML_UNARY_OP_ABS:
1035 case GGML_UNARY_OP_SGN:
1036 case GGML_UNARY_OP_STEP:
1037 case GGML_UNARY_OP_HARDSWISH:
1038 case GGML_UNARY_OP_HARDSIGMOID:
1039 case GGML_UNARY_OP_EXP:
1040 case GGML_UNARY_OP_SOFTPLUS:
1041 case GGML_UNARY_OP_EXPM1:
1042 return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
1043 default:
1044 return false;
1045 }
1046 case GGML_OP_GLU:
1047 switch (ggml_get_glu_op(op)) {
1048 case GGML_GLU_OP_REGLU:
1049 case GGML_GLU_OP_GEGLU:
1050 case GGML_GLU_OP_SWIGLU:
1051 case GGML_GLU_OP_SWIGLU_OAI:
1052 case GGML_GLU_OP_GEGLU_ERF:
1053 case GGML_GLU_OP_GEGLU_QUICK:
1054 return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1055 default:
1056 return false;
1057 }
1058 case GGML_OP_NONE:
1059 case GGML_OP_RESHAPE:
1060 case GGML_OP_VIEW:
1061 case GGML_OP_TRANSPOSE:
1062 case GGML_OP_PERMUTE:
1063 case GGML_OP_CONCAT:
1064 return true;
1065 case GGML_OP_ADD:
1066 case GGML_OP_SUB:
1067 case GGML_OP_MUL:
1068 case GGML_OP_DIV:
1069 case GGML_OP_ADD_ID:
1070 return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
1071 case GGML_OP_ACC:
1072 case GGML_OP_REPEAT:
1073 case GGML_OP_CONV_TRANSPOSE_1D:
1074 return true;
1075 case GGML_OP_CONV_TRANSPOSE_2D:
1076 return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) &&
1077 (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
1078 op->src[1]->type == GGML_TYPE_F32 &&
1079 op->type == GGML_TYPE_F32;
1080 case GGML_OP_SUM:
1081 return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
1082 case GGML_OP_TRI:
1083 return ggml_is_contiguous_rows(op->src[0]);
1084 case GGML_OP_SUM_ROWS:
1085 case GGML_OP_CUMSUM:
1086 case GGML_OP_MEAN:
1087 case GGML_OP_SOFT_MAX:
1088 case GGML_OP_GROUP_NORM:
1089 case GGML_OP_L2_NORM:
1090 return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
1091 case GGML_OP_COUNT_EQUAL:
1092 return has_simdgroup_reduction &&
1093 op->src[0]->type == GGML_TYPE_I32 &&
1094 op->src[1]->type == GGML_TYPE_I32 &&
1095 op->type == GGML_TYPE_I64;
1096 case GGML_OP_ARGMAX:
1097 return has_simdgroup_reduction;
1098 case GGML_OP_NORM:
1099 case GGML_OP_RMS_NORM:
1100 return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
1101 case GGML_OP_ROPE:
1102 return true;
1103 case GGML_OP_IM2COL:
1104 return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
1105 case GGML_OP_CONV_2D:
1106 return ggml_is_contiguous(op->src[0]) &&
1107 op->src[1]->type == GGML_TYPE_F32 &&
1108 op->type == GGML_TYPE_F32 &&
1109 (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
1110 case GGML_OP_UPSCALE:
1111 return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
1112 case GGML_OP_POOL_1D:
1113 return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1114 case GGML_OP_POOL_2D:
1115 return op->src[0]->type == GGML_TYPE_F32;
1116 case GGML_OP_PAD:
1117 // TODO: add circular padding support for metal, see https://github.com/ggml-org/llama.cpp/pull/16985
1118 if (ggml_get_op_params_i32(op, 8) != 0) {
1119 return false;
1120 }
1121
1122 return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
1123 (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
1124 case GGML_OP_PAD_REFLECT_1D:
1125 case GGML_OP_TIMESTEP_EMBEDDING:
1126 case GGML_OP_LEAKY_RELU:
1127 return op->src[0]->type == GGML_TYPE_F32;
1128 case GGML_OP_ARGSORT:
1129 case GGML_OP_TOP_K:
1130 case GGML_OP_ARANGE:
1131 return true;
1132 case GGML_OP_FLASH_ATTN_EXT:
1133 // for new head sizes, add checks here
1134 if (op->src[0]->ne[0] != 32 &&
1135 op->src[0]->ne[0] != 40 &&
1136 op->src[0]->ne[0] != 48 &&
1137 op->src[0]->ne[0] != 64 &&
1138 op->src[0]->ne[0] != 72 &&
1139 op->src[0]->ne[0] != 80 &&
1140 op->src[0]->ne[0] != 96 &&
1141 op->src[0]->ne[0] != 112 &&
1142 op->src[0]->ne[0] != 128 &&
1143 op->src[0]->ne[0] != 192 &&
1144 op->src[0]->ne[0] != 256 &&
1145 op->src[0]->ne[0] != 576) {
1146 return false;
1147 }
1148 if (op->src[1]->type != op->src[2]->type) {
1149 return false;
1150 }
1151 return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
1152 case GGML_OP_SSM_CONV:
1153 case GGML_OP_SSM_SCAN:
1154 return has_simdgroup_reduction;
1155 case GGML_OP_RWKV_WKV6:
1156 case GGML_OP_RWKV_WKV7:
1157 return true;
1158 case GGML_OP_SOLVE_TRI:
1159 case GGML_OP_MUL_MAT:
1160 case GGML_OP_MUL_MAT_ID:
1161 return has_simdgroup_reduction;
1162 case GGML_OP_CPY:
1163 case GGML_OP_DUP:
1164 case GGML_OP_CONT:
1165 {
1166 switch (op->src[0]->type) {
1167 case GGML_TYPE_F32:
1168 switch (op->type) {
1169 case GGML_TYPE_F32:
1170 case GGML_TYPE_F16:
1171 case GGML_TYPE_BF16:
1172 case GGML_TYPE_Q8_0:
1173 case GGML_TYPE_Q4_0:
1174 case GGML_TYPE_Q4_1:
1175 case GGML_TYPE_Q5_0:
1176 case GGML_TYPE_Q5_1:
1177 case GGML_TYPE_IQ4_NL:
1178 case GGML_TYPE_I32:
1179 return true;
1180 default:
1181 return false;
1182 }
1183 case GGML_TYPE_F16:
1184 switch (op->type) {
1185 case GGML_TYPE_F32:
1186 case GGML_TYPE_F16:
1187 return true;
1188 default:
1189 return false;
1190 }
1191 case GGML_TYPE_BF16:
1192 switch (op->type) {
1193 case GGML_TYPE_F32:
1194 case GGML_TYPE_BF16:
1195 return true;
1196 default:
1197 return false;
1198 }
1199 case GGML_TYPE_Q4_0:
1200 case GGML_TYPE_Q4_1:
1201 case GGML_TYPE_Q5_0:
1202 case GGML_TYPE_Q5_1:
1203 case GGML_TYPE_Q8_0:
1204 switch (op->type) {
1205 case GGML_TYPE_F32:
1206 case GGML_TYPE_F16:
1207 return true;
1208 default:
1209 return false;
1210 }
1211 case GGML_TYPE_I32:
1212 return op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32;
1213 default:
1214 return false;
1215 };
1216 }
1217 case GGML_OP_GET_ROWS:
1218 return true;
1219 case GGML_OP_SET_ROWS:
1220 {
1221 if (op->src[0]->type != GGML_TYPE_F32) {
1222 return false;
1223 }
1224
1225 switch (op->type) {
1226 case GGML_TYPE_F32:
1227 case GGML_TYPE_F16:
1228 case GGML_TYPE_BF16:
1229 case GGML_TYPE_Q8_0:
1230 case GGML_TYPE_Q4_0:
1231 case GGML_TYPE_Q4_1:
1232 case GGML_TYPE_Q5_0:
1233 case GGML_TYPE_Q5_1:
1234 case GGML_TYPE_IQ4_NL:
1235 return true;
1236 default:
1237 return false;
1238 };
1239 }
1240 case GGML_OP_DIAG:
1241 return true;
1242 case GGML_OP_OPT_STEP_ADAMW:
1243 case GGML_OP_OPT_STEP_SGD:
1244 return has_simdgroup_reduction;
1245 default:
1246 return false;
1247 }
1248}
1249
1250const struct ggml_metal_device_props * ggml_metal_device_get_props(ggml_metal_device_t dev) {
1251 return &dev->props;
1252}
1253
1254//
1255// device buffers
1256//
1257
1258// max memory buffers that can be mapped to the device
1259#define GGML_METAL_MAX_BUFFERS 64
1260
1261struct ggml_metal_buffer_wrapper {
1262 void * data;
1263 size_t size;
1264
1265 id<MTLBuffer> metal;
1266};
1267
1268struct ggml_metal_buffer {
1269 void * all_data;
1270 size_t all_size;
1271
1272 // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
1273 bool is_shared;
1274 bool owned;
1275
1276 // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap
1277 int n_buffers;
1278 struct ggml_metal_buffer_wrapper buffers[GGML_METAL_MAX_BUFFERS];
1279
1280 bool use_residency_sets;
1281
1282 // optional MTLResidencySet
1283 // note: cannot use "id<MTLResidencySet>" explicitly here because the type is not available on certain OSes
1284 id rset;
1285
1286 // pointers to global device
1287 ggml_metal_device_t dev;
1288};
1289
1290static void ggml_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
1291#ifndef GGML_METAL_NDEBUG
1292#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
1293 if (@available(macOS 10.12, iOS 16.0, *)) {
1294 GGML_LOG_DEBUG("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)\n",
1295 __func__,
1296 size_aligned / 1024.0 / 1024.0,
1297 device.currentAllocatedSize / 1024.0 / 1024.0,
1298 device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
1299
1300 if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) {
1301 GGML_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__);
1302 }
1303 } else {
1304 GGML_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
1305 __func__,
1306 size_aligned / 1024.0 / 1024.0,
1307 device.currentAllocatedSize / 1024.0 / 1024.0);
1308 }
1309#endif
1310#endif
1311 GGML_UNUSED(device);
1312 GGML_UNUSED(size_aligned);
1313}
1314
1315// rset init
1316static bool ggml_metal_buffer_rset_init(ggml_metal_buffer_t buf) {
1317 buf->rset = nil;
1318
1319 if (!buf->use_residency_sets) {
1320 return true;
1321 }
1322
1323#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
1324 if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
1325 MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init];
1326 desc.label = @"ggml_metal";
1327 desc.initialCapacity = buf->n_buffers;
1328
1329 NSError * error;
1330 buf->rset = [buf->dev->mtl_device newResidencySetWithDescriptor:desc error:&error];
1331 if (error) {
1332 GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
1333 [desc release];
1334 return false;
1335 }
1336
1337 [desc release];
1338
1339 for (int i = 0; i < buf->n_buffers; i++) {
1340 [buf->rset addAllocation:buf->buffers[i].metal];
1341 }
1342
1343 [buf->rset commit];
1344 [buf->rset requestResidency];
1345
1346 return true;
1347 }
1348#endif
1349
1350 return true;
1351}
1352
1353// rset free
1354static void ggml_metal_buffer_rset_free(ggml_metal_buffer_t buf) {
1355#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
1356 if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
1357 if (buf->rset) {
1358 [buf->rset endResidency];
1359 [buf->rset removeAllAllocations];
1360 [buf->rset release];
1361 }
1362 }
1363#else
1364 GGML_UNUSED(buf);
1365#endif
1366}
1367
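// allocate page-aligned host memory - page alignment of both pointer and length is required by
// newBufferWithBytesNoCopy when the memory is later wrapped in a shared Metal buffer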
1368static void * ggml_metal_host_malloc(size_t n) {
1369 void * data = NULL;
1370
1371#if TARGET_OS_OSX
1372 kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE);
1373 if (err != KERN_SUCCESS) {
1374 GGML_LOG_ERROR("%s: error: vm_allocate failed\n", __func__);
1375 return NULL;
1376 }
1377#else
1378 const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n);
1379 if (result != 0) {
1380 GGML_LOG_ERROR("%s: error: posix_memalign failed\n", __func__);
1381 return NULL;
1382 }
1383#endif
1384
1385 return data;
1386}
1387
1388ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared) {
1389 ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer));
1390
1391 res->dev = dev;
1392
1393 const size_t size_page = sysconf(_SC_PAGESIZE);
1394
1395 size_t size_aligned = size;
1396 if ((size_aligned % size_page) != 0) {
1397 size_aligned += (size_page - (size_aligned % size_page));
1398 }
1399
1400 const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);
1401
1402 shared = shared && props_dev->use_shared_buffers;
1403
1404 // allocate shared buffer if the device supports it and it is required by the buffer type
1405 if (shared) {
1406 res->all_data = ggml_metal_host_malloc(size_aligned);
1407 res->is_shared = true;
1408 } else {
1409 // use virtual address
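 // note: this address is never dereferenced on the host - it only provides a unique,
 // stable base for computing tensor offsets (see ggml_metal_buffer_get_id) while the
 // actual data lives in private GPU memory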
1410 res->all_data = (void *) atomic_fetch_add_explicit(&dev->addr_virt, size_aligned, memory_order_relaxed);
1411 res->is_shared = false;
1412 }
1413 res->all_size = size_aligned;
1414
1415 res->owned = true;
1416
1417 res->n_buffers = 1;
1418
1419 if (res->all_data != NULL) {
1420 res->buffers[0].size = size;
1421 res->buffers[0].metal = nil;
1422
1423 if (size_aligned > 0) {
1424 if (props_dev->use_shared_buffers && shared) {
1425 res->buffers[0].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:res->all_data
1426 length:size_aligned
1427 options:MTLResourceStorageModeShared
1428 deallocator:nil];
1429 } else {
1430 res->buffers[0].metal = [res->dev->mtl_device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
1431 }
1432 }
1433
1434 res->buffers[0].data = res->all_data;
1435 }
1436
1437 if (size_aligned > 0 && (res->all_data == NULL || res->buffers[0].metal == nil)) {
1438 GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
1439 free(res);
1440 return NULL;
1441 }
1442
1443 res->use_residency_sets = props_dev->use_residency_sets;
1444
1445 if (!ggml_metal_buffer_rset_init(res)) {
1446 GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
1447 free(res);
1448 return NULL;
1449 }
1450
1451 ggml_metal_device_rsets_add(dev, res->rset);
1452
1453 //ggml_metal_log_allocated_size(device, size_aligned);
1454
1455 return res;
1456}
1457
1458ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size) {
1459 ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer));
1460
1461 res->dev = dev;
1462
1463 res->all_data = ptr;
1464 res->all_size = size;
1465
1466 res->is_shared = true;
1467 res->owned = false;
1468
1469 res->n_buffers = 0;
1470
1471 const size_t size_page = sysconf(_SC_PAGESIZE);
1472
1473 // page-align the data ptr
1474 {
1475 const uintptr_t offs = (uintptr_t) ptr % size_page;
1476 ptr = (void *) ((char *) ptr - offs);
1477 size += offs;
1478 }
1479
1480 size_t size_aligned = size;
1481 if ((size_aligned % size_page) != 0) {
1482 size_aligned += (size_page - (size_aligned % size_page));
1483 }
1484
1485 const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);
1486
1487 // the buffer fits into the max buffer size allowed by the device
1488 if (size_aligned <= props_dev->max_buffer_size) {
1489 res->buffers[res->n_buffers].data = ptr;
1490 res->buffers[res->n_buffers].size = size;
1491 res->buffers[res->n_buffers].metal = nil;
1492
1493 if (size_aligned > 0) {
1494 res->buffers[res->n_buffers].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
1495
1496 if (res->buffers[res->n_buffers].metal == nil) {
1497 GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
1498 free(res);
1499 return NULL;
1500 }
1501 }
1502
1503 ggml_metal_log_allocated_size(res->dev->mtl_device, size_aligned);
1504
1505 ++res->n_buffers;
1506 } else {
1507 // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
1508 // one of the views
1509 const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
1510 const size_t size_step = props_dev->max_buffer_size - size_ovlp;
1511 const size_t size_view = props_dev->max_buffer_size;
1512
1513 for (size_t i = 0; i < size; i += size_step) {
1514 const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
1515
1516 res->buffers[res->n_buffers].data = (void *) ((uint8_t *) ptr + i);
1517 res->buffers[res->n_buffers].size = size_step_aligned;
1518 res->buffers[res->n_buffers].metal = nil;
1519
1520 if (size_step_aligned > 0) {
1521 res->buffers[res->n_buffers].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
1522
1523 if (res->buffers[res->n_buffers].metal == nil) {
1524 GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
1525 free(res);
1526 return NULL;
1527 }
1528 }
1529
1530 ggml_metal_log_allocated_size(res->dev->mtl_device, size_step_aligned);
1531
1532 if (i + size_step < size) {
1533 GGML_LOG_INFO("\n");
1534 }
1535
1536 ++res->n_buffers;
1537 }
1538 }
1539
1540 res->use_residency_sets = props_dev->use_residency_sets;
1541
1542 if (!ggml_metal_buffer_rset_init(res)) {
1543 GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__);
1544 free(res);
1545 return NULL;
1546 }
1547
1548 ggml_metal_device_rsets_add(dev, res->rset);
1549
1550 return res;
1551}
1552
1553void ggml_metal_buffer_free(ggml_metal_buffer_t buf) {
1554 ggml_metal_device_rsets_rm(buf->dev, buf->rset);
1555
1556 for (int i = 0; i < buf->n_buffers; i++) {
1557 [buf->buffers[i].metal release];
1558 }
1559
1560 ggml_metal_buffer_rset_free(buf);
1561
1562 if (buf->is_shared && buf->owned) {
1563#if TARGET_OS_OSX
1564 vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)buf->all_data, buf->all_size);
1565#else
1566 free(buf->all_data);
1567#endif
1568 }
1569
1570 free(buf);
1571}
1572
1573void * ggml_metal_buffer_get_base(ggml_metal_buffer_t buf) {
1574 return buf->all_data;
1575}
1576
1577bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) {
1578 return buf->is_shared;
1579}
1580
1581void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
1582 if (buf->is_shared) {
1583 memset((char *) tensor->data + offset, value, size);
1584 return;
1585 }
1586
1587 @autoreleasepool {
1588 // dst
1589 struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf, tensor);
1590 bid_dst.offs += offset;
1591
1592 id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences];
1593
1594 {
1595 id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
1596
1597 [encoder fillBuffer:bid_dst.metal
1598 range:NSMakeRange(bid_dst.offs, size)
1599 value:value];
1600
1601 [encoder endEncoding];
1602 }
1603
1604 [cmd_buf commit];
1605 [cmd_buf waitUntilCompleted];
1606 }
1607}
1608
1609void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
1610 if (buf->is_shared) {
1611 memcpy((char *) tensor->data + offset, data, size);
1612 return;
1613 }
1614
1615 @autoreleasepool {
1616 // src
1617 void * data_ptr = (void *)(uintptr_t) data; // "const cast" the src data
1618 id<MTLBuffer> buf_src = [buf->dev->mtl_device newBufferWithBytesNoCopy:data_ptr
1619 length:size
1620 options:MTLResourceStorageModeShared
1621 deallocator:nil];
1622
1623 GGML_ASSERT(buf_src);
1624
1625 // dst
1626 struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf, tensor);
1627 bid_dst.offs += offset;
1628
1629 // note: for experimentation purposes, here we use a semaphore to wait for the copy to complete
1630 // this is an alternative to waitUntilCompleted, which should be faster, but doesn't seem to make much difference
1631 dispatch_semaphore_t completion_semaphore = dispatch_semaphore_create(0);
1632
1633 id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences];
1634
1635 {
1636 id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
1637
1638 [encoder copyFromBuffer:buf_src
1639 sourceOffset:0
1640 toBuffer:bid_dst.metal
1641 destinationOffset:bid_dst.offs
1642 size:size];
1643
1644 [encoder endEncoding];
1645 }
1646
1647 [cmd_buf addCompletedHandler:^(id<MTLCommandBuffer> cb) {
1648 // TODO: can check for errors here
1649 GGML_UNUSED(cb);
1650
1651 dispatch_semaphore_signal(completion_semaphore);
1652 }];
1653
1654 [cmd_buf commit];
1655
1656 dispatch_semaphore_wait(completion_semaphore, DISPATCH_TIME_FOREVER);
1657 dispatch_release(completion_semaphore);
1658
1659 //[cmd_buf waitUntilCompleted];
1660 }
1661}
1662
1663void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
1664 if (buf->is_shared) {
1665 memcpy(data, (const char *) tensor->data + offset, size);
1666 return;
1667 }
1668
1669 @autoreleasepool {
1670 // src
1671 struct ggml_metal_buffer_id bid_src = ggml_metal_buffer_get_id(buf, tensor);
1672 bid_src.offs += offset;
1673
1674 // dst
1675 id<MTLBuffer> buf_dst = [buf->dev->mtl_device newBufferWithBytesNoCopy:data
1676 length:size
1677 options:MTLResourceStorageModeShared
1678 deallocator:nil];
1679
1680 GGML_ASSERT(buf_dst);
1681
1682 id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences];
1683
1684 {
1685 id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
1686
1687 [encoder copyFromBuffer:bid_src.metal
1688 sourceOffset:bid_src.offs
1689 toBuffer:buf_dst
1690 destinationOffset:0
1691 size:size];
1692
1693 [encoder endEncoding];
1694 }
1695
1696 [cmd_buf commit];
1697 [cmd_buf waitUntilCompleted];
1698 }
1699}
1700
1701void ggml_metal_buffer_clear(ggml_metal_buffer_t buf, uint8_t value) {
1702 if (buf->is_shared) {
1703 memset(buf->all_data, value, buf->all_size);
1704 return;
1705 }
1706
1707 @autoreleasepool {
1708 id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences];
1709
1710 {
1711 id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];
1712
1713 [encoder fillBuffer:buf->buffers[0].metal
1714 range:NSMakeRange(0, buf->buffers[0].size)
1715 value:value];
1716
1717 [encoder endEncoding];
1718 }
1719
1720 [cmd_buf commit];
1721 [cmd_buf waitUntilCompleted];
1722 }
1723}
1724
1725struct ggml_metal_buffer_id ggml_metal_buffer_get_id(ggml_metal_buffer_t buf, const struct ggml_tensor * t) {
1726 struct ggml_metal_buffer_id res = { nil, 0 };
1727
1728 const int64_t tsize = ggml_nbytes(t);
1729
1730 // find the view that contains the tensor fully
1731 for (int i = 0; i < buf->n_buffers; ++i) {
1732 const int64_t ioffs = (int64_t) t->data - (int64_t) buf->buffers[i].data;
1733
1734 //GGML_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf->buffers[i].size);
1735 if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf->buffers[i].size) {
1736 res.metal = buf->buffers[i].metal;
1737 res.offs = (size_t) ioffs;
1738
1739 //GGML_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs);
1740
1741 return res;
1742 }
1743 }
1744
1745 GGML_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name);
1746
1747 return res;
1748}
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h b/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h
new file mode 100644
index 0000000..952e1be
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -0,0 +1,1051 @@
1#ifndef GGML_METAL_IMPL
2#define GGML_METAL_IMPL
3
4// kernel parameters for mat-vec threadgroups
5//
6// N_R0: number of src0 rows to process per simdgroup
7// N_SG: number of simdgroups per threadgroup
8//
9// TODO: for optimal performance, these should become functions of the device and work size
10
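// e.g. with N_R0_Q4_0 = 4 and N_SG_Q4_0 = 2 below, a threadgroup runs 2 simdgroups and covers
// 4*2 = 8 rows of src0 (an illustrative reading of the definitions; the exact mapping is up to
// the kernels)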
11#define N_R0_Q4_0 4
12#define N_SG_Q4_0 2
13
14#define N_R0_Q4_1 4
15#define N_SG_Q4_1 2
16
17#define N_R0_Q5_0 4
18#define N_SG_Q5_0 2
19
20#define N_R0_Q5_1 4
21#define N_SG_Q5_1 2
22
23#define N_R0_Q8_0 2
24#define N_SG_Q8_0 4
25
26#define N_R0_MXFP4 2
27#define N_SG_MXFP4 2
28
29#define N_R0_Q2_K 4
30#define N_SG_Q2_K 2
31
32#define N_R0_Q3_K 2
33#define N_SG_Q3_K 2
34
35#define N_R0_Q4_K 2
36#define N_SG_Q4_K 2
37
38#define N_R0_Q5_K 2
39#define N_SG_Q5_K 2
40
41#define N_R0_Q6_K 2
42#define N_SG_Q6_K 2
43
44#define N_R0_IQ1_S 4
45#define N_SG_IQ1_S 2
46
47#define N_R0_IQ1_M 4
48#define N_SG_IQ1_M 2
49
50#define N_R0_IQ2_XXS 4
51#define N_SG_IQ2_XXS 2
52
53#define N_R0_IQ2_XS 4
54#define N_SG_IQ2_XS 2
55
56#define N_R0_IQ2_S 4
57#define N_SG_IQ2_S 2
58
59#define N_R0_IQ3_XXS 4
60#define N_SG_IQ3_XXS 2
61
62#define N_R0_IQ3_S 4
63#define N_SG_IQ3_S 2
64
65#define N_R0_IQ4_NL 2
66#define N_SG_IQ4_NL 2
67
68#define N_R0_IQ4_XS 2
69#define N_SG_IQ4_XS 2
70
71// function constants offsets
72#define FC_FLASH_ATTN_EXT_PAD 100
73#define FC_FLASH_ATTN_EXT_BLK 200
74#define FC_FLASH_ATTN_EXT 300
75#define FC_FLASH_ATTN_EXT_VEC 400
76#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500
77#define FC_MUL_MV 600
78#define FC_MUL_MM 700
79#define FC_ROPE 800
80#define FC_SSM_CONV 900
81#define FC_SOLVE_TRI 1000
82#define FC_COUNT_EQUAL 1100
83#define FC_UNARY 1200
84#define FC_BIN 1300
85
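// each family reserves a block of 100 indices; individual function constants in the shaders
// are presumably addressed relative to these bases (e.g. FC_MUL_MV + 0, FC_MUL_MV + 1, ...)
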
86// op-specific constants
87#define OP_FLASH_ATTN_EXT_NQPSG 8
88#define OP_FLASH_ATTN_EXT_NCPSG 64
89
90#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
91#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
92
93#define OP_UNARY_NUM_SCALE 10
94#define OP_UNARY_NUM_FILL 11
95#define OP_UNARY_NUM_CLAMP 12
96#define OP_UNARY_NUM_SQR 13
97#define OP_UNARY_NUM_SQRT 14
98#define OP_UNARY_NUM_SIN 15
99#define OP_UNARY_NUM_COS 16
100#define OP_UNARY_NUM_LOG 17
101#define OP_UNARY_NUM_LEAKY_RELU 18
102
103#define OP_UNARY_NUM_TANH 100
104#define OP_UNARY_NUM_RELU 101
105#define OP_UNARY_NUM_SIGMOID 102
106#define OP_UNARY_NUM_GELU 103
107#define OP_UNARY_NUM_GELU_ERF 104
108#define OP_UNARY_NUM_GELU_QUICK 105
109#define OP_UNARY_NUM_SILU 106
110#define OP_UNARY_NUM_ELU 107
111#define OP_UNARY_NUM_NEG 108
112#define OP_UNARY_NUM_ABS 109
113#define OP_UNARY_NUM_SGN 110
114#define OP_UNARY_NUM_STEP 111
115#define OP_UNARY_NUM_HARDSWISH 112
116#define OP_UNARY_NUM_HARDSIGMOID 113
117#define OP_UNARY_NUM_EXP 114
118#define OP_UNARY_NUM_SOFTPLUS 115
119#define OP_UNARY_NUM_EXPM1 116
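
// note: the values below 100 appear to mirror top-level GGML_OP_* codes that reuse the unary
// kernel (scale/fill/clamp/...), while the values from 100 up appear to map GGML_UNARY_OP_*
// variants offset by 100 (an assumption based on the grouping, not a documented contract)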
120
121
122// kernel argument structs
123//
124// - element counters (e.g. ne00) typically use int32_t to reduce register usage
125// however, be careful about int overflows when using them in the kernel implementation
126//
127// - strides (e.g. nb00) use uint64_t
128
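// e.g. a product like ne01*ne02*ne03 can overflow int32_t for large tensors; widen before
// multiplying (illustrative):
//
//   const int64_t nrows = (int64_t) args.ne01 * args.ne02 * args.ne03;
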
129typedef struct {
130 int32_t ne00;
131 int32_t ne01;
132 int32_t ne02;
133 int32_t ne03;
134 uint64_t nb00;
135 uint64_t nb01;
136 uint64_t nb02;
137 uint64_t nb03;
138 int32_t ne10;
139 int32_t ne11;
140 int32_t ne12;
141 int32_t ne13;
142 uint64_t nb10;
143 uint64_t nb11;
144 uint64_t nb12;
145 uint64_t nb13;
146 int32_t ne0;
147 int32_t ne1;
148 int32_t ne2;
149 int32_t ne3;
150 uint64_t nb0;
151 uint64_t nb1;
152 uint64_t nb2;
153 uint64_t nb3;
154 int32_t dim;
155} ggml_metal_kargs_concat;
156
157typedef struct {
158 int32_t ne00;
159 int32_t ne01;
160 int32_t ne02;
161 int32_t ne03;
162 uint64_t nb00;
163 uint64_t nb01;
164 uint64_t nb02;
165 uint64_t nb03;
166 int32_t ne0;
167 int32_t ne1;
168 int32_t ne2;
169 int32_t ne3;
170 uint64_t nb0;
171 uint64_t nb1;
172 uint64_t nb2;
173 uint64_t nb3;
174 float slope;
175 float scale;
176 float bias;
177 float val;
178 float min;
179 float max;
180} ggml_metal_kargs_unary;
181
182typedef struct {
183 int32_t ne00;
184 int32_t ne01;
185 int32_t ne02;
186 int32_t ne03;
187 uint64_t nb00;
188 uint64_t nb01;
189 uint64_t nb02;
190 uint64_t nb03;
191 int32_t ne10;
192 int32_t ne11;
193 int32_t ne12;
194 int32_t ne13;
195 uint64_t nb10;
196 uint64_t nb11;
197 uint64_t nb12;
198 uint64_t nb13;
199 int32_t ne0;
200 int32_t ne1;
201 int32_t ne2;
202 int32_t ne3;
203 uint64_t nb0;
204 uint64_t nb1;
205 uint64_t nb2;
206 uint64_t nb3;
207 uint64_t offs;
208 uint64_t o1[8];
209} ggml_metal_kargs_bin;
210
211typedef struct {
212 int64_t ne0;
213 int64_t ne1;
214 size_t nb01;
215 size_t nb02;
216 size_t nb11;
217 size_t nb21;
218} ggml_metal_kargs_add_id;
219
220typedef struct {
221 int32_t ne00;
222 int32_t ne01;
223 int32_t ne02;
224 int32_t ne03;
225 uint64_t nb00;
226 uint64_t nb01;
227 uint64_t nb02;
228 uint64_t nb03;
229 int32_t ne0;
230 int32_t ne1;
231 int32_t ne2;
232 int32_t ne3;
233 uint64_t nb0;
234 uint64_t nb1;
235 uint64_t nb2;
236 uint64_t nb3;
237} ggml_metal_kargs_repeat;
238
239typedef struct {
240 int64_t nk0;
241 int64_t ne00;
242 int64_t ne01;
243 int64_t ne02;
244 int64_t ne03;
245 uint64_t nb00;
246 uint64_t nb01;
247 uint64_t nb02;
248 uint64_t nb03;
249 int64_t ne0;
250 int64_t ne1;
251 int64_t ne2;
252 int64_t ne3;
253 uint64_t nb0;
254 uint64_t nb1;
255 uint64_t nb2;
256 uint64_t nb3;
257} ggml_metal_kargs_cpy;
258
259typedef struct {
260 int64_t ne10;
261 int64_t ne11;
262 int64_t ne12;
263 uint64_t nb10;
264 uint64_t nb11;
265 uint64_t nb12;
266 uint64_t nb13;
267 uint64_t nb1;
268 uint64_t nb2;
269 uint64_t nb3;
270 uint64_t offs;
271 bool inplace;
272} ggml_metal_kargs_set;
273
274typedef struct {
275 int32_t ne00;
276 int32_t ne01;
277 int32_t ne02;
278 int32_t ne03;
279 uint64_t nb00;
280 uint64_t nb01;
281 uint64_t nb02;
282 uint64_t nb03;
283 int32_t ne0;
284 int32_t ne1;
285 int32_t ne2;
286 int32_t ne3;
287 uint64_t nb0;
288 uint64_t nb1;
289 uint64_t nb2;
290 uint64_t nb3;
291 int32_t n_past;
292 int32_t n_dims;
293 int32_t n_ctx_orig;
294 float freq_base;
295 float freq_scale;
296 float ext_factor;
297 float attn_factor;
298 float beta_fast;
299 float beta_slow;
300 int32_t sect_0;
301 int32_t sect_1;
302 int32_t sect_2;
303 int32_t sect_3;
304 bool src2;
305} ggml_metal_kargs_rope;
306
307typedef struct {
308 int32_t ne11;
309 int32_t ne_12_2; // assume K and V are same shape
310 int32_t ne_12_3;
311 uint64_t nb11;
312 uint64_t nb12;
313 uint64_t nb13;
314 uint64_t nb21;
315 uint64_t nb22;
316 uint64_t nb23;
317 int32_t ne31;
318 int32_t ne32;
319 int32_t ne33;
320 uint64_t nb31;
321 uint64_t nb32;
322 uint64_t nb33;
323} ggml_metal_kargs_flash_attn_ext_pad;
324
325typedef struct {
326 int32_t ne01;
327 int32_t ne30;
328 int32_t ne31;
329 int32_t ne32;
330 int32_t ne33;
331 uint64_t nb31;
332 uint64_t nb32;
333 uint64_t nb33;
334} ggml_metal_kargs_flash_attn_ext_blk;
335
336typedef struct {
337 int32_t ne01;
338 int32_t ne02;
339 int32_t ne03;
340 uint64_t nb01;
341 uint64_t nb02;
342 uint64_t nb03;
343 int32_t ne11;
344 int32_t ne_12_2; // assume K and V are same shape
345 int32_t ne_12_3;
346 int32_t ns10;
347 uint64_t nb11;
348 uint64_t nb12;
349 uint64_t nb13;
350 int32_t ns20;
351 uint64_t nb21;
352 uint64_t nb22;
353 uint64_t nb23;
354 int32_t ne31;
355 int32_t ne32;
356 int32_t ne33;
357 uint64_t nb31;
358 uint64_t nb32;
359 uint64_t nb33;
360 int32_t ne1;
361 int32_t ne2;
362 int32_t ne3;
363 float scale;
364 float max_bias;
365 float m0;
366 float m1;
367 int32_t n_head_log2;
368 float logit_softcap;
369} ggml_metal_kargs_flash_attn_ext;
370
371typedef struct {
372 int32_t ne01;
373 int32_t ne02;
374 int32_t ne03;
375 uint64_t nb01;
376 uint64_t nb02;
377 uint64_t nb03;
378 int32_t ne11;
379 int32_t ne_12_2; // assume K and V are same shape
380 int32_t ne_12_3;
381 int32_t ns10;
382 uint64_t nb11;
383 uint64_t nb12;
384 uint64_t nb13;
385 int32_t ns20;
386 uint64_t nb21;
387 uint64_t nb22;
388 uint64_t nb23;
389 int32_t ne31;
390 int32_t ne32;
391 int32_t ne33;
392 uint64_t nb31;
393 uint64_t nb32;
394 uint64_t nb33;
395 int32_t ne1;
396 int32_t ne2;
397 int32_t ne3;
398 float scale;
399 float max_bias;
400 float m0;
401 float m1;
402 int32_t n_head_log2;
403 float logit_softcap;
404} ggml_metal_kargs_flash_attn_ext_vec;
405
406typedef struct {
407 int32_t nrows;
408} ggml_metal_kargs_flash_attn_ext_vec_reduce;
409
410typedef struct {
411 int32_t ne00;
412 int32_t ne02;
413 uint64_t nb01;
414 uint64_t nb02;
415 uint64_t nb03;
416 int32_t ne12;
417 uint64_t nb10;
418 uint64_t nb11;
419 uint64_t nb12;
420 uint64_t nb13;
421 int32_t ne0;
422 int32_t ne1;
423 int16_t r2;
424 int16_t r3;
425} ggml_metal_kargs_mul_mm;
426
427typedef struct {
428 int32_t ne00;
429 int32_t ne01;
430 int32_t ne02;
431 uint64_t nb00;
432 uint64_t nb01;
433 uint64_t nb02;
434 uint64_t nb03;
435 int32_t ne10;
436 int32_t ne11;
437 int32_t ne12;
438 uint64_t nb10;
439 uint64_t nb11;
440 uint64_t nb12;
441 uint64_t nb13;
442 int32_t ne0;
443 int32_t ne1;
444 int32_t nr0;
445 int16_t r2;
446 int16_t r3;
447} ggml_metal_kargs_mul_mv;
448
449typedef struct {
450 int32_t ne00;
451 int32_t ne01;
452 int32_t ne02;
453 uint64_t nb00;
454 uint64_t nb01;
455 uint64_t nb02;
456 uint64_t nb03;
457 int32_t ne10;
458 int32_t ne11;
459 int32_t ne12;
460 uint64_t nb10;
461 uint64_t nb11;
462 uint64_t nb12;
463 uint64_t nb13;
464 int32_t ne0;
465 int32_t ne1;
466 int16_t r2;
467 int16_t r3;
468} ggml_metal_kargs_mul_mv_ext;
469
470typedef struct {
471 int32_t ne02;
472 int32_t ne10;
473 int32_t ne11; // n_expert_used (bcast)
474 uint64_t nb11;
475 uint64_t nb12;
476 int32_t ne21; // n_tokens
477 int32_t ne20; // n_expert_used
478 uint64_t nb21;
479} ggml_metal_kargs_mul_mm_id_map0;
480
481typedef struct {
482 int32_t ne00;
483 int32_t ne02;
484 uint64_t nb01;
485 uint64_t nb02;
486 uint64_t nb03;
487 int32_t ne11;
488 uint64_t nb10;
489 uint64_t nb11;
490 uint64_t nb12;
491 uint64_t nb13;
492 int32_t ne20;
493 int32_t ne21;
494 int32_t ne0;
495 int32_t ne1;
496 int16_t r2;
497 int16_t r3;
498} ggml_metal_kargs_mul_mm_id;
499
500typedef struct {
501 int32_t nei0;
502 int32_t nei1;
503 uint64_t nbi1;
504 int32_t ne00;
505 int32_t ne01;
506 int32_t ne02;
507 uint64_t nb00;
508 uint64_t nb01;
509 uint64_t nb02;
510 int32_t ne10;
511 int32_t ne11;
512 int32_t ne12;
513 int32_t ne13;
514 uint64_t nb10;
515 uint64_t nb11;
516 uint64_t nb12;
517 int32_t ne0;
518 int32_t ne1;
519 uint64_t nb1;
520 int32_t nr0;
521} ggml_metal_kargs_mul_mv_id;
522
523// NORM
524// RMS_NORM
525typedef struct {
526 int32_t ne00;
527 int32_t ne00_t;
528 uint64_t nb1;
529 uint64_t nb2;
530 uint64_t nb3;
531 float eps;
532 int32_t nef1[3];
533 int32_t nef2[3];
534 int32_t nef3[3];
535 uint64_t nbf1[3];
536 uint64_t nbf2[3];
537 uint64_t nbf3[3];
538} ggml_metal_kargs_norm;
539
540typedef struct {
541 int32_t ne00;
542 int32_t ne01;
543 int32_t ne02;
544 int32_t ne03;
545 uint64_t nb00;
546 uint64_t nb01;
547 uint64_t nb02;
548 uint64_t nb03;
549 int32_t ne0;
550 int32_t ne1;
551 int32_t ne2;
552 int32_t ne3;
553 uint64_t nb0;
554 uint64_t nb1;
555 uint64_t nb2;
556 uint64_t nb3;
557 float eps;
558} ggml_metal_kargs_l2_norm;
559
560typedef struct {
561 int64_t ne00;
562 int64_t ne01;
563 int64_t ne02;
564 uint64_t nb00;
565 uint64_t nb01;
566 uint64_t nb02;
567 int32_t ngrp;
568 float eps;
569} ggml_metal_kargs_group_norm;
570
571typedef struct {
572 int32_t IC;
573 int32_t IL;
574 int32_t K;
575 int32_t s0;
576 uint64_t nb0;
577 uint64_t nb1;
578} ggml_metal_kargs_conv_transpose_1d;
579
580typedef struct {
581 int32_t IC;
582 int32_t IH;
583 int32_t IW;
584 int32_t KH;
585 int32_t KW;
586 int32_t OC;
587 int32_t s0;
588 uint64_t nb0;
589 uint64_t nb1;
590 uint64_t nb2;
591} ggml_metal_kargs_conv_transpose_2d;
592
593typedef struct {
594 uint64_t nb00;
595 uint64_t nb01;
596 uint64_t nb02;
597 uint64_t nb03;
598 uint64_t nb10;
599 uint64_t nb11;
600 uint64_t nb12;
601 uint64_t nb13;
602 uint64_t nb0;
603 uint64_t nb1;
604 uint64_t nb2;
605 uint64_t nb3;
606 int32_t IW;
607 int32_t IH;
608 int32_t KW;
609 int32_t KH;
610 int32_t IC;
611 int32_t OC;
612 int32_t OW;
613 int32_t OH;
614 int32_t N;
615 int32_t s0;
616 int32_t s1;
617 int32_t p0;
618 int32_t p1;
619 int32_t d0;
620 int32_t d1;
621} ggml_metal_kargs_conv_2d;
622
623typedef struct {
624 uint64_t ofs0;
625 uint64_t ofs1;
626 int32_t IW;
627 int32_t IH;
628 int32_t CHW;
629 int32_t s0;
630 int32_t s1;
631 int32_t p0;
632 int32_t p1;
633 int32_t d0;
634 int32_t d1;
635 int32_t N;
636 int32_t KH;
637 int32_t KW;
638 int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
639} ggml_metal_kargs_im2col;
640
641typedef struct {
642 int32_t ne00;
643 uint64_t nb01;
644 int32_t ne10;
645 uint64_t nb11;
646 int32_t ne0;
647 uint64_t nb1;
648 int32_t i00;
649 int32_t i10;
650 float alpha;
651 float limit;
652} ggml_metal_kargs_glu;
653
654typedef struct {
655 uint64_t np;
656} ggml_metal_kargs_sum;
657
658typedef struct {
659 int64_t ne00;
660 int64_t ne01;
661 int64_t ne02;
662 int64_t ne03;
663 uint64_t nb00;
664 uint64_t nb01;
665 uint64_t nb02;
666 uint64_t nb03;
667 int64_t ne0;
668 int64_t ne1;
669 int64_t ne2;
670 int64_t ne3;
671 uint64_t nb0;
672 uint64_t nb1;
673 uint64_t nb2;
674 uint64_t nb3;
675} ggml_metal_kargs_sum_rows;
676
677typedef struct {
678 int64_t ne00;
679 int64_t ne01;
680 int64_t ne02;
681 int64_t ne03;
682 uint64_t nb00;
683 uint64_t nb01;
684 uint64_t nb02;
685 uint64_t nb03;
686 int64_t net0;
687 int64_t net1;
688 int64_t net2;
689 int64_t net3;
690 uint64_t nbt0;
691 uint64_t nbt1;
692 uint64_t nbt2;
693 uint64_t nbt3;
694 bool outb;
695} ggml_metal_kargs_cumsum_blk;
696
697typedef struct {
698 int64_t ne00;
699 int64_t ne01;
700 int64_t ne02;
701 int64_t ne03;
702 uint64_t nb00;
703 uint64_t nb01;
704 uint64_t nb02;
705 uint64_t nb03;
706 int64_t net0;
707 int64_t net1;
708 int64_t net2;
709 int64_t net3;
710 uint64_t nbt0;
711 uint64_t nbt1;
712 uint64_t nbt2;
713 uint64_t nbt3;
714} ggml_metal_kargs_cumsum_add;
715
716typedef struct {
717 int32_t ne00;
718 int32_t ne01;
719 int32_t ne02;
720 uint64_t nb01;
721 uint64_t nb02;
722 uint64_t nb03;
723 int32_t ne11;
724 int32_t ne12;
725 int32_t ne13;
726 uint64_t nb11;
727 uint64_t nb12;
728 uint64_t nb13;
729 uint64_t nb1;
730 uint64_t nb2;
731 uint64_t nb3;
732 float scale;
733 float max_bias;
734 float m0;
735 float m1;
736 int32_t n_head_log2;
737} ggml_metal_kargs_soft_max;
738
739typedef struct {
740 int64_t ne00;
741 int64_t ne01;
742 int64_t ne02;
743 uint64_t nb00;
744 uint64_t nb01;
745 uint64_t nb02;
746 int64_t ne10;
747 int64_t ne11;
748 uint64_t nb10;
749 uint64_t nb11;
750 int64_t ne0;
751 int64_t ne1;
752 int64_t ne2;
753 uint64_t nb0;
754 uint64_t nb1;
755 uint64_t nb2;
756} ggml_metal_kargs_ssm_conv;
757
758typedef struct {
759 int64_t d_state;
760 int64_t d_inner;
761 int64_t n_head;
762 int64_t n_group;
763 int64_t n_seq_tokens;
764 int64_t n_seqs;
765 uint64_t s_off;
766 uint64_t nb00;
767 uint64_t nb01;
768 uint64_t nb02;
769 uint64_t nb03;
770 uint64_t nb10;
771 uint64_t nb11;
772 uint64_t nb12;
773 uint64_t ns12;
774 uint64_t nb13;
775 uint64_t nb20;
776 uint64_t nb21;
777 uint64_t ns21;
778 uint64_t nb22;
779 int64_t ne30;
780 uint64_t nb31;
781 uint64_t nb41;
782 uint64_t nb42;
783 uint64_t ns42;
784 uint64_t nb43;
785 uint64_t nb51;
786 uint64_t nb52;
787 uint64_t ns52;
788 uint64_t nb53;
789 uint64_t nb0;
790} ggml_metal_kargs_ssm_scan;
791
792typedef struct {
793 int32_t ne00;
794 int32_t ne01;
795 int32_t ne02;
796 int32_t ne03;
797 uint64_t nb00;
798 uint64_t nb01;
799 uint64_t nb02;
800 uint64_t nb03;
801 int32_t ne10;
802 int32_t ne11;
803 int32_t ne12;
804 int32_t ne13;
805 uint64_t nb10;
806 uint64_t nb11;
807 uint64_t nb12;
808 uint64_t nb13;
809 int32_t ne0;
810 int32_t ne1;
811 int32_t ne2;
812 int32_t ne3;
813 uint64_t nb0;
814 uint64_t nb1;
815 uint64_t nb2;
816 uint64_t nb3;
817} ggml_metal_kargs_solve_tri;
818
819typedef struct {
820 int32_t ne00t;
821 int32_t ne00;
822 uint64_t nb01;
823 uint64_t nb02;
824 uint64_t nb03;
825 int32_t ne10;
826 uint64_t nb10;
827 uint64_t nb11;
828 uint64_t nb12;
829 uint64_t nb1;
830 uint64_t nb2;
831 uint64_t nb3;
832} ggml_metal_kargs_get_rows;
833
834typedef struct {
835 int32_t nk0;
836 int32_t ne01;
837 uint64_t nb01;
838 uint64_t nb02;
839 uint64_t nb03;
840 int32_t ne11;
841 int32_t ne12;
842 uint64_t nb10;
843 uint64_t nb11;
844 uint64_t nb12;
845 uint64_t nb1;
846 uint64_t nb2;
847 uint64_t nb3;
848} ggml_metal_kargs_set_rows;
849
850typedef struct {
851 int32_t ne00;
852 int32_t ne01;
853 int32_t ne02;
854 int32_t ne03;
855 uint64_t nb00;
856 uint64_t nb01;
857 uint64_t nb02;
858 uint64_t nb03;
859 int32_t ne0;
860 int32_t ne1;
861 int32_t ne2;
862 int32_t ne3;
863 uint64_t nb0;
864 uint64_t nb1;
865 uint64_t nb2;
866 uint64_t nb3;
867} ggml_metal_kargs_diag;
868
869typedef struct {
870 int64_t ne00;
871 int64_t ne01;
872 int64_t ne02;
873 int64_t ne03;
874 uint64_t nb00;
875 uint64_t nb01;
876 uint64_t nb02;
877 uint64_t nb03;
878 int64_t ne0;
879 int64_t ne1;
880 int64_t ne2;
881 int64_t ne3;
882 uint64_t nb0;
883 uint64_t nb1;
884 uint64_t nb2;
885 uint64_t nb3;
886 float sf0;
887 float sf1;
888 float sf2;
889 float sf3;
890} ggml_metal_kargs_upscale;
891
892typedef struct {
893 int64_t ne00;
894 int64_t ne01;
895 int64_t ne02;
896 int64_t ne03;
897 uint64_t nb00;
898 uint64_t nb01;
899 uint64_t nb02;
900 uint64_t nb03;
901 int64_t ne0;
902 int64_t ne1;
903 int64_t ne2;
904 int64_t ne3;
905 uint64_t nb0;
906 uint64_t nb1;
907 uint64_t nb2;
908 uint64_t nb3;
909} ggml_metal_kargs_pad;
910
911typedef struct {
912 int64_t ne00;
913 int64_t ne01;
914 int64_t ne02;
915 int64_t ne03;
916 uint64_t nb00;
917 uint64_t nb01;
918 uint64_t nb02;
919 uint64_t nb03;
920 int64_t ne0;
921 int64_t ne1;
922 int64_t ne2;
923 int64_t ne3;
924 uint64_t nb0;
925 uint64_t nb1;
926 uint64_t nb2;
927 uint64_t nb3;
928 int32_t p0;
929 int32_t p1;
930} ggml_metal_kargs_pad_reflect_1d;
931
932typedef struct {
933 uint64_t nb1;
934 int dim;
935 int max_period;
936} ggml_metal_kargs_timestep_embedding;
937
938typedef struct {
939 int32_t ne00;
940 int32_t ne01;
941 int32_t ne02;
942 int32_t ne03;
943 uint64_t nb00;
944 uint64_t nb01;
945 uint64_t nb02;
946 uint64_t nb03;
947 int32_t ne0;
948 int32_t ne1;
949 int32_t ne2;
950 int32_t ne3;
951 uint64_t nb0;
952 uint64_t nb1;
953 uint64_t nb2;
954 uint64_t nb3;
955} ggml_metal_kargs_tri;
956
957typedef struct {
958 int32_t ne00;
959 int32_t ne01;
960 int32_t ne02;
961 int32_t ne03;
962 uint64_t nb00;
963 uint64_t nb01;
964 uint64_t nb02;
965 uint64_t nb03;
966 int32_t ne0;
967 int32_t ne1;
968 int32_t ne2;
969 int32_t ne3;
970 int32_t top_k;
971} ggml_metal_kargs_argsort;
972
973typedef struct {
974 int64_t ne00;
975 int64_t ne01;
976 int64_t ne02;
977 int64_t ne03;
978 uint64_t nb00;
979 uint64_t nb01;
980 uint64_t nb02;
981 uint64_t nb03;
982 int32_t ne0;
983 int32_t ne1;
984 int32_t ne2;
985 int32_t ne3;
986 int32_t top_k;
987 int32_t len;
988} ggml_metal_kargs_argsort_merge;
989
990typedef struct {
991 int64_t ne0;
992 float start;
993 float step;
994} ggml_metal_kargs_arange;
995
996typedef struct {
997 int64_t val;
998} ggml_metal_kargs_memset;
999
1000typedef struct {
1001 int32_t ne00;
1002 int32_t ne01;
1003 int32_t ne02;
1004 int32_t ne03;
1005 uint64_t nb00;
1006 uint64_t nb01;
1007 uint64_t nb02;
1008 uint64_t nb03;
1009 uint64_t nb10;
1010 uint64_t nb11;
1011 uint64_t nb12;
1012 uint64_t nb13;
1013} ggml_metal_kargs_count_equal;
1014
1015typedef struct {
1016 int32_t k0;
1017 int32_t k1;
1018 int32_t s0;
1019 int32_t s1;
1020 int32_t p0;
1021 int32_t p1;
1022 int64_t IH;
1023 int64_t IW;
1024 int64_t OH;
1025 int64_t OW;
1026 int64_t np;
1027} ggml_metal_kargs_pool_2d;
1028
1029typedef struct {
1030 int32_t k0;
1031 int32_t s0;
1032 int32_t p0;
1033 int64_t IW;
1034 int64_t OW;
1035 int64_t np;
1036} ggml_metal_kargs_pool_1d;
1037
1038typedef struct {
1039 int64_t ne00;
1040 uint64_t nb01;
1041} ggml_metal_kargs_argmax;
1042
1043typedef struct {
1044 int64_t np;
1045} ggml_metal_kargs_opt_step_adamw;
1046
1047typedef struct {
1048 int64_t np;
1049} ggml_metal_kargs_opt_step_sgd;
1050
1051#endif // GGML_METAL_IMPL
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp b/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp
new file mode 100644
index 0000000..7db95d1
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -0,0 +1,4222 @@
1#include "ggml-metal-ops.h"
2
3#include "ggml.h"
4#include "ggml-impl.h"
5#include "ggml-backend-impl.h"
6
7#include "ggml-metal-impl.h"
8#include "ggml-metal-common.h"
9#include "ggml-metal-device.h"
10
11#include <cassert>
12#include <algorithm>
13#include <limits>
14#include <cmath>
15
16static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
17 if (!t) {
18 return { nullptr, 0 };
19 }
20
21 ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
22
23 ggml_metal_buffer_t ctx = (ggml_metal_buffer_t) buffer->context;
24
25 return ggml_metal_buffer_get_id(ctx, t);
26}
27
28struct ggml_metal_op {
29 ggml_metal_op(
30 ggml_metal_device_t dev,
31 ggml_metal_cmd_buf_t cmd_buf,
32 ggml_cgraph * gf,
33 int idx_start,
34 int idx_end,
35 bool use_fusion,
36 bool use_concurrency,
37 bool use_capture,
38 int debug_graph,
39 int debug_fusion) {
40 this->dev = dev;
41 this->lib = ggml_metal_device_get_library(dev);
42 this->enc = ggml_metal_encoder_init(cmd_buf, use_concurrency);
43 this->mem_ranges = ggml_mem_ranges_init(debug_graph);
44 this->idx_start = idx_start;
45 this->idx_end = idx_end;
46 this->use_fusion = use_fusion;
47 this->use_concurrency = use_concurrency;
48 this->use_capture = use_capture;
49 this->debug_graph = debug_graph;
50 this->debug_fusion = debug_fusion;
51 this->gf = gf;
52
53 idxs.reserve(gf->n_nodes);
54
55 // filter empty nodes
56 // TODO: this can be removed when the allocator starts filtering them earlier
57 // https://github.com/ggml-org/llama.cpp/pull/16130#issuecomment-3327905830
58 for (int i = idx_start; i < idx_end; i++) {
59 if (!ggml_op_is_empty(gf->nodes[i]->op) && !ggml_is_empty(gf->nodes[i])) {
60 idxs.push_back(i);
61 }
62 }
63 }
64
65 ~ggml_metal_op() {
66 ggml_metal_encoder_end_encoding(this->enc);
67 ggml_metal_encoder_free(this->enc);
68 ggml_mem_ranges_free(this->mem_ranges);
69 }
70
71 int n_nodes() const {
72 return idxs.size();
73 }
74
75 ggml_tensor * node(int i) const {
76 assert(i >= 0 && i < (int) idxs.size());
77 return ggml_graph_node(gf, idxs[i]);
78 }
79
80 bool can_fuse(int i0, const ggml_op * ops, int n_ops) const {
81 assert(use_fusion);
82 assert(i0 >= 0 && i0 < n_nodes());
83
84 if (i0 + n_ops > n_nodes()) {
85 return false;
86 }
87
88 return ggml_can_fuse_ext(gf, idxs.data() + i0, ops, n_ops);
89 }
90
91 ggml_metal_device_t dev;
92 ggml_metal_library_t lib;
93 ggml_metal_encoder_t enc;
94 ggml_mem_ranges_t mem_ranges;
95
96 bool use_fusion;
97 bool use_concurrency;
98 bool use_capture;
99
100 int debug_graph;
101 int debug_fusion;
102
103private:
104 ggml_cgraph * gf;
105
106 int idx_start;
107 int idx_end;
108
109 // non-empty node indices
110 std::vector<int> idxs;
111};
112
113ggml_metal_op_t ggml_metal_op_init(
114 ggml_metal_device_t dev,
115 ggml_metal_cmd_buf_t cmd_buf,
116 ggml_cgraph * gf,
117 int idx_start,
118 int idx_end,
119 bool use_fusion,
120 bool use_concurrency,
121 bool use_capture,
122 int debug_graph,
123 int debug_fusion) {
124 ggml_metal_op_t res = new ggml_metal_op(
125 dev,
126 cmd_buf,
127 gf,
128 idx_start,
129 idx_end,
130 use_fusion,
131 use_concurrency,
132 use_capture,
133 debug_graph,
134 debug_fusion);
135
136 return res;
137}
138
139void ggml_metal_op_free(ggml_metal_op_t ctx) {
140 delete ctx;
141}
142
143int ggml_metal_op_n_nodes(ggml_metal_op_t ctx) {
144 return ctx->n_nodes();
145}
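
// usage sketch (illustrative; mirrors how the backend is expected to drive this API):
//
//   ggml_metal_op_t op = ggml_metal_op_init(dev, cmd_buf, gf, 0, gf->n_nodes,
//                                           use_fusion, use_concurrency, use_capture, 0, 0);
//   for (int i = 0; i < ggml_metal_op_n_nodes(op); ) {
//       i += ggml_metal_op_encode(op, i); // returns how many nodes were consumed (fusion)
//   }
//   ggml_metal_op_free(op); // ends encoding and releases the encoder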
146
147static bool ggml_metal_op_concurrency_reset(ggml_metal_op_t ctx) {
148 if (!ctx->mem_ranges) {
149 return true;
150 }
151
152 ggml_metal_encoder_memory_barrier(ctx->enc);
153
154 ggml_mem_ranges_reset(ctx->mem_ranges);
155
156 return true;
157}
158
159static bool ggml_metal_op_concurrency_check(ggml_metal_op_t ctx, const ggml_tensor * node) {
160 if (!ctx->mem_ranges) {
161 return false;
162 }
163
164 return ggml_mem_ranges_check(ctx->mem_ranges, node);
165}
166
167static bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor * node) {
168 if (!ctx->mem_ranges) {
169 return true;
170 }
171
172 return ggml_mem_ranges_add(ctx->mem_ranges, node);
173}
174
175static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
176 struct ggml_tensor * node = ctx->node(idx);
177
178 //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
179
180 if (ggml_is_empty(node)) {
181 return 1;
182 }
183
184 switch (node->op) {
185 case GGML_OP_NONE:
186 case GGML_OP_RESHAPE:
187 case GGML_OP_VIEW:
188 case GGML_OP_TRANSPOSE:
189 case GGML_OP_PERMUTE:
190 {
191 // noop -> next node
192 if (ctx->debug_graph > 0) {
193 GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), "(noop)");
194 }
195 } return 1;
196 default:
197 {
198 } break;
199 }
200
201 if (!ggml_metal_device_supports_op(ctx->dev, node)) {
202 GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(node));
203 GGML_ABORT("unsupported op");
204 }
205
206 if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
207 return 1;
208 }
209
210 int n_fuse = 1;
211
212 // check if the current node can run concurrently with other nodes before it
213 // the condition is that:
214 // - the current node cannot write to any previous src or dst ranges
215 // - the current node cannot read from any previous dst ranges
216 //
217 // if the condition is not satisfied, we put a memory barrier and clear all ranges
218 // otherwise, we add the new ranges to the encoding context and process the node concurrently
219 //
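    // e.g. (illustrative): two MUL_MAT nodes with disjoint srcs and dsts are encoded back-to-back
    // with no barrier, while a node that reads the previous node's dst forces a barrier and a reset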
220 {
221 const bool is_concurrent = ggml_metal_op_concurrency_check(ctx, node);
222
223 if (!is_concurrent) {
224 ggml_metal_op_concurrency_reset(ctx);
225 }
226
227 if (ctx->debug_graph > 0) {
228 GGML_LOG_DEBUG("%s: node[%5d] - %-12s %-12s %s\n", __func__, idx, ggml_op_name(node->op), ggml_get_name(node), is_concurrent ? "(concurrent)" : "");
229 }
230 if (ctx->debug_graph > 1) {
231 GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne);
232 GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb);
233 GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne);
234 GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb);
235 GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne);
236 GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb);
237 GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne);
238 GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb);
239 GGML_TENSOR_LOCALS( int64_t, ne, node, ne);
240 GGML_TENSOR_LOCALS(uint64_t, nb, node, nb);
241
242 if (node->src[0]) {
243 GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[0]->type), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
244 ggml_is_contiguous(node->src[0]), node->src[0]->name);
245 }
246 if (node->src[1]) {
247 GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
248 ggml_is_contiguous(node->src[1]), node->src[1]->name);
249 }
250 if (node->src[2]) {
251 GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23,
252 ggml_is_contiguous(node->src[2]), node->src[2]->name);
253 }
254 if (node->src[3]) {
255 GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33,
256 ggml_is_contiguous(node->src[3]), node->src[3]->name);
257 }
258 if (node) {
259 GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
260 node->name);
261 }
262 }
263 }
264
265 switch (node->op) {
266 case GGML_OP_CONCAT:
267 {
268 n_fuse = ggml_metal_op_concat(ctx, idx);
269 } break;
270 case GGML_OP_ADD:
271 case GGML_OP_SUB:
272 case GGML_OP_MUL:
273 case GGML_OP_DIV:
274 {
275 n_fuse = ggml_metal_op_bin(ctx, idx);
276 } break;
277 case GGML_OP_ADD_ID:
278 {
279 n_fuse = ggml_metal_op_add_id(ctx, idx);
280 } break;
281 case GGML_OP_REPEAT:
282 {
283 n_fuse = ggml_metal_op_repeat(ctx, idx);
284 } break;
285 case GGML_OP_ACC:
286 {
287 n_fuse = ggml_metal_op_acc(ctx, idx);
288 } break;
289 case GGML_OP_SCALE:
290 case GGML_OP_FILL:
291 case GGML_OP_CLAMP:
292 case GGML_OP_LEAKY_RELU:
293 case GGML_OP_SQR:
294 case GGML_OP_SQRT:
295 case GGML_OP_SIN:
296 case GGML_OP_COS:
297 case GGML_OP_LOG:
298 case GGML_OP_UNARY:
299 {
300 n_fuse = ggml_metal_op_unary(ctx, idx);
301 } break;
302 case GGML_OP_GLU:
303 {
304 n_fuse = ggml_metal_op_glu(ctx, idx);
305 } break;
306 case GGML_OP_SUM:
307 {
308 n_fuse = ggml_metal_op_sum(ctx, idx);
309 } break;
310 case GGML_OP_SUM_ROWS:
311 case GGML_OP_MEAN:
312 {
313 n_fuse = ggml_metal_op_sum_rows(ctx, idx);
314 } break;
315 case GGML_OP_CUMSUM:
316 {
317 n_fuse = ggml_metal_op_cumsum(ctx, idx);
318 } break;
319 case GGML_OP_SOFT_MAX:
320 {
321 n_fuse = ggml_metal_op_soft_max(ctx, idx);
322 } break;
323 case GGML_OP_SSM_CONV:
324 {
325 n_fuse = ggml_metal_op_ssm_conv(ctx, idx);
326 } break;
327 case GGML_OP_SSM_SCAN:
328 {
329 n_fuse = ggml_metal_op_ssm_scan(ctx, idx);
330 } break;
331 case GGML_OP_RWKV_WKV6:
332 case GGML_OP_RWKV_WKV7:
333 {
334 n_fuse = ggml_metal_op_rwkv(ctx, idx);
335 } break;
336 case GGML_OP_SOLVE_TRI:
337 {
338 n_fuse = ggml_metal_op_solve_tri(ctx, idx);
339 } break;
340 case GGML_OP_MUL_MAT:
341 {
342 n_fuse = ggml_metal_op_mul_mat(ctx, idx);
343 } break;
344 case GGML_OP_MUL_MAT_ID:
345 {
346 n_fuse = ggml_metal_op_mul_mat_id(ctx, idx);
347 } break;
348 case GGML_OP_GET_ROWS:
349 {
350 n_fuse = ggml_metal_op_get_rows(ctx, idx);
351 } break;
352 case GGML_OP_SET_ROWS:
353 {
354 n_fuse = ggml_metal_op_set_rows(ctx, idx);
355 } break;
356 case GGML_OP_DIAG:
357 {
358 n_fuse = ggml_metal_op_diag(ctx, idx);
359 } break;
360 case GGML_OP_L2_NORM:
361 {
362 n_fuse = ggml_metal_op_l2_norm(ctx, idx);
363 } break;
364 case GGML_OP_GROUP_NORM:
365 {
366 n_fuse = ggml_metal_op_group_norm(ctx, idx);
367 } break;
368 case GGML_OP_NORM:
369 case GGML_OP_RMS_NORM:
370 {
371 n_fuse = ggml_metal_op_norm(ctx, idx);
372 } break;
373 case GGML_OP_ROPE:
374 {
375 n_fuse = ggml_metal_op_rope(ctx, idx);
376 } break;
377 case GGML_OP_IM2COL:
378 {
379 n_fuse = ggml_metal_op_im2col(ctx, idx);
380 } break;
381 case GGML_OP_CONV_2D:
382 {
383 n_fuse = ggml_metal_op_conv_2d(ctx, idx);
384 } break;
385 case GGML_OP_CONV_TRANSPOSE_1D:
386 {
387 n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
388 } break;
389 case GGML_OP_CONV_TRANSPOSE_2D:
390 {
391 n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
392 } break;
393 case GGML_OP_UPSCALE:
394 {
395 n_fuse = ggml_metal_op_upscale(ctx, idx);
396 } break;
397 case GGML_OP_PAD:
398 {
399 n_fuse = ggml_metal_op_pad(ctx, idx);
400 } break;
401 case GGML_OP_PAD_REFLECT_1D:
402 {
403 n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx);
404 } break;
405 case GGML_OP_ARANGE:
406 {
407 n_fuse = ggml_metal_op_arange(ctx, idx);
408 } break;
409 case GGML_OP_TIMESTEP_EMBEDDING:
410 {
411 n_fuse = ggml_metal_op_timestep_embedding(ctx, idx);
412 } break;
413 case GGML_OP_ARGSORT:
414 {
415 n_fuse = ggml_metal_op_argsort(ctx, idx);
416 } break;
417 case GGML_OP_TOP_K:
418 {
419 n_fuse = ggml_metal_op_top_k(ctx, idx);
420 } break;
421 case GGML_OP_TRI:
422 {
423 n_fuse = ggml_metal_op_tri(ctx, idx);
424 } break;
425 case GGML_OP_FLASH_ATTN_EXT:
426 {
427 n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
428 } break;
429 case GGML_OP_DUP:
430 case GGML_OP_CPY:
431 case GGML_OP_CONT:
432 {
433 n_fuse = ggml_metal_op_cpy(ctx, idx);
434 } break;
435 case GGML_OP_POOL_1D:
436 {
437 n_fuse = ggml_metal_op_pool_1d(ctx, idx);
438 } break;
439 case GGML_OP_POOL_2D:
440 {
441 n_fuse = ggml_metal_op_pool_2d(ctx, idx);
442 } break;
443 case GGML_OP_ARGMAX:
444 {
445 n_fuse = ggml_metal_op_argmax(ctx, idx);
446 } break;
447 case GGML_OP_OPT_STEP_ADAMW:
448 {
449 n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
450 } break;
451 case GGML_OP_OPT_STEP_SGD:
452 {
453 n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
454 } break;
455 case GGML_OP_COUNT_EQUAL:
456 {
457 n_fuse = ggml_metal_op_count_equal(ctx, idx);
458 } break;
459 default:
460 {
461 GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
462 GGML_ABORT("fatal error");
463 }
464 }
465
466 if (ctx->debug_graph > 0) {
467 if (n_fuse > 1) {
468 GGML_LOG_DEBUG("%s: fuse %d ops\n", __func__, n_fuse);
469 }
470 }
471
472 // update the mem ranges in the encoding context
473 for (int i = 0; i < n_fuse; ++i) {
474 if (!ggml_metal_op_concurrency_add(ctx, ctx->node(idx + i))) {
475 ggml_metal_op_concurrency_reset(ctx);
476 }
477 }
478
479 return n_fuse;
480}
481
482int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) {
483 if (ctx->use_capture) {
484 ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ctx->node(idx)));
485 }
486
487 int res = ggml_metal_op_encode_impl(ctx, idx);
488 if (idx + res > ctx->n_nodes()) {
489 GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
490 "https://github.com/ggml-org/llama.cpp/pull/14849");
491 }
492
493 if (ctx->use_capture) {
494 ggml_metal_encoder_debug_group_pop(ctx->enc);
495 }
496
497 return res;
498}
499
500int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
501 ggml_tensor * op = ctx->node(idx);
502
503 ggml_metal_library_t lib = ctx->lib;
504 ggml_metal_encoder_t enc = ctx->enc;
505
506 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
507 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
508 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
509 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
510 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
511 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
512
513 const int32_t dim = ((const int32_t *) op->op_params)[0];
514
515 ggml_metal_kargs_concat args = {
516 /*.ne00 =*/ ne00,
517 /*.ne01 =*/ ne01,
518 /*.ne02 =*/ ne02,
519 /*.ne03 =*/ ne03,
520 /*.nb00 =*/ nb00,
521 /*.nb01 =*/ nb01,
522 /*.nb02 =*/ nb02,
523 /*.nb03 =*/ nb03,
524 /*.ne10 =*/ ne10,
525 /*.ne11 =*/ ne11,
526 /*.ne12 =*/ ne12,
527 /*.ne13 =*/ ne13,
528 /*.nb10 =*/ nb10,
529 /*.nb11 =*/ nb11,
530 /*.nb12 =*/ nb12,
531 /*.nb13 =*/ nb13,
532 /*.ne0 =*/ ne0,
533 /*.ne1 =*/ ne1,
534 /*.ne2 =*/ ne2,
535 /*.ne3 =*/ ne3,
536 /*.nb0 =*/ nb0,
537 /*.nb1 =*/ nb1,
538 /*.nb2 =*/ nb2,
539 /*.nb3 =*/ nb3,
540 /*.dim =*/ dim,
541 };
542
543 auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
544
545 ggml_metal_encoder_set_pipeline(enc, pipeline);
546 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
547 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
548 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
549 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
550
551 const int nth = std::min(1024, ne0);
552
553 ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
554
555 return 1;
556}
557
558int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
559 ggml_tensor * op = ctx->node(idx);
560
561 ggml_metal_library_t lib = ctx->lib;
562 ggml_metal_encoder_t enc = ctx->enc;
563
564 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
565 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
566 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
567 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
568
569 auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
570
571 ggml_metal_kargs_repeat args = {
572 /*.ne00 =*/ ne00,
573 /*.ne01 =*/ ne01,
574 /*.ne02 =*/ ne02,
575 /*.ne03 =*/ ne03,
576 /*.nb00 =*/ nb00,
577 /*.nb01 =*/ nb01,
578 /*.nb02 =*/ nb02,
579 /*.nb03 =*/ nb03,
580 /*.ne0 =*/ ne0,
581 /*.ne1 =*/ ne1,
582 /*.ne2 =*/ ne2,
583 /*.ne3 =*/ ne3,
584 /*.nb0 =*/ nb0,
585 /*.nb1 =*/ nb1,
586 /*.nb2 =*/ nb2,
587 /*.nb3 =*/ nb3,
588 };
589
590 ggml_metal_encoder_set_pipeline(enc, pipeline);
591 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
592 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
593 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
594
595 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
596
597 ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
598
599 return 1;
600}
601
602int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
603 ggml_tensor * op = ctx->node(idx);
604
605 ggml_metal_library_t lib = ctx->lib;
606 ggml_metal_encoder_t enc = ctx->enc;
607
608 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
609 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
610 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
611 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
612 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
613 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
614
615 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
616 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
617 GGML_ASSERT(op->type == GGML_TYPE_F32);
618
619 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
620 GGML_ASSERT(ggml_is_contiguous(op->src[1]));
621
622 const size_t pnb1 = ((const int32_t *) op->op_params)[0];
623 const size_t pnb2 = ((const int32_t *) op->op_params)[1];
624 const size_t pnb3 = ((const int32_t *) op->op_params)[2];
625 const size_t offs = ((const int32_t *) op->op_params)[3];
626
627 const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
628
629 if (!inplace) {
630 // run a separate kernel to cpy src->dst
631 // not sure how to avoid this
632 // TODO: make a simpler cpy_bytes kernel
633
634 //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
635 auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
636
637 ggml_metal_kargs_cpy args = {
638 /*.nk0 =*/ ne00,
639 /*.ne00 =*/ ne00,
640 /*.ne01 =*/ ne01,
641 /*.ne02 =*/ ne02,
642 /*.ne03 =*/ ne03,
643 /*.nb00 =*/ nb00,
644 /*.nb01 =*/ nb01,
645 /*.nb02 =*/ nb02,
646 /*.nb03 =*/ nb03,
647 /*.ne0 =*/ ne0,
648 /*.ne1 =*/ ne1,
649 /*.ne2 =*/ ne2,
650 /*.ne3 =*/ ne3,
651 /*.nb0 =*/ nb0,
652 /*.nb1 =*/ nb1,
653 /*.nb2 =*/ nb2,
654 /*.nb3 =*/ nb3,
655 };
656
657 ggml_metal_encoder_set_pipeline(enc, pipeline);
658 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
659 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
660 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
661
662 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
663
664 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
665
666 ggml_metal_op_concurrency_reset(ctx);
667 }
668
669 ggml_metal_kargs_bin args = {
670 /*.ne00 =*/ ne00,
671 /*.ne01 =*/ ne01,
672 /*.ne02 =*/ ne02,
673 /*.ne03 =*/ ne03,
674 /*.nb00 =*/ nb00,
675 /*.nb01 =*/ pnb1,
676 /*.nb02 =*/ pnb2,
677 /*.nb03 =*/ pnb3,
678 /*.ne10 =*/ ne10,
679 /*.ne11 =*/ ne11,
680 /*.ne12 =*/ ne12,
681 /*.ne13 =*/ ne13,
682 /*.nb10 =*/ nb10,
683 /*.nb11 =*/ nb11,
684 /*.nb12 =*/ nb12,
685 /*.nb13 =*/ nb13,
686 /*.ne0 =*/ ne0,
687 /*.ne1 =*/ ne1,
688 /*.ne2 =*/ ne2,
689 /*.ne3 =*/ ne3,
690 /*.nb0 =*/ nb0,
691 /*.nb1 =*/ pnb1,
692 /*.nb2 =*/ pnb2,
693 /*.nb3 =*/ pnb3,
694 /*.offs =*/ offs,
695 /*.o1 =*/ { 0 },
696 };
697
698 auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);
699
700 ggml_metal_encoder_set_pipeline(enc, pipeline);
701 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
702 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
703 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
704 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
705
706 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
707
708 ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1);
709
710 return 1;
711}
712
713int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
714 ggml_tensor * op = ctx->node(idx);
715
716 ggml_metal_library_t lib = ctx->lib;
717 ggml_metal_encoder_t enc = ctx->enc;
718
719 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
720 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
721 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
722 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
723
724 GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
725
726 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
727 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
728
729 ggml_metal_kargs_unary args = {
730 /*.ne00 =*/ ne00,
731 /*.ne01 =*/ ne01,
732 /*.ne02 =*/ ne02,
733 /*.ne03 =*/ ne03,
734 /*.nb00 =*/ nb00,
735 /*.nb01 =*/ nb01,
736 /*.nb02 =*/ nb02,
737 /*.nb03 =*/ nb03,
738 /*.ne0 =*/ ne0,
739 /*.ne1 =*/ ne1,
740 /*.ne2 =*/ ne2,
741 /*.ne3 =*/ ne3,
742 /*.nb0 =*/ nb0,
743 /*.nb1 =*/ nb1,
744 /*.nb2 =*/ nb2,
745 /*.nb3 =*/ nb3,
746 /*.slope =*/ 0.0,
747 /*.scale =*/ 0.0,
748 /*.bias =*/ 0.0,
749 /*.val =*/ 0.0,
750 /*.min =*/ 0.0,
751 /*.max =*/ 0.0,
752 };
753
754 if (op->op == GGML_OP_LEAKY_RELU) {
755 args.slope = ggml_get_op_params_f32(op, 0);
756 }
757
758 if (op->op == GGML_OP_SCALE) {
759 args.scale = ggml_get_op_params_f32(op, 0);
760 args.bias = ggml_get_op_params_f32(op, 1);
761 }
762
763 if (op->op == GGML_OP_FILL) {
764 args.val = ggml_get_op_params_f32(op, 0);
765 }
766
767 if (op->op == GGML_OP_CLAMP) {
768 args.min = ggml_get_op_params_f32(op, 0);
769 args.max = ggml_get_op_params_f32(op, 1);
770 }
771
772 auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
773
774 if (pipeline.c4) {
775 args.ne00 = ne00/4;
776 args.ne0 = ne0/4;
777 }
778
779 ggml_metal_encoder_set_pipeline(enc, pipeline);
780 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
781 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
782 ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
783
784 if (pipeline.cnt) {
785 const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
786
787 ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
788 } else {
789 const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
790
791 const int nth = MIN(args.ne00, nth_max);
792
793 const int nk0 = (args.ne00 + nth - 1)/nth;
794
795 ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
796 }
797
798 return 1;
799}
800
801int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
802 ggml_tensor * op = ctx->node(idx);
803
804 ggml_metal_library_t lib = ctx->lib;
805 ggml_metal_encoder_t enc = ctx->enc;
806
807 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
808 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
809 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
810 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
811 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
812 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
813
814 if (op->src[1]) {
815 GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
816 }
817
818 auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
819
820 const int32_t swp = ggml_get_op_params_i32(op, 1);
821 const float alpha = ggml_get_op_params_f32(op, 2);
822 const float limit = ggml_get_op_params_f32(op, 3);
823
824 const int32_t i00 = swp ? ne0 : 0;
825 const int32_t i10 = swp ? 0 : ne0;
826
827 ggml_metal_kargs_glu args = {
828 /*.ne00 =*/ ne00,
829 /*.nb01 =*/ nb01,
830 /*.ne10 =*/ op->src[1] ? ne10 : ne00,
831 /*.nb11 =*/ op->src[1] ? nb11 : nb01,
832 /*.ne0 =*/ ne0,
833 /*.nb1 =*/ nb1,
834 /*.i00 =*/ op->src[1] ? 0 : i00,
835 /*.i10 =*/ op->src[1] ? 0 : i10,
836 /*.alpha=*/ alpha,
837 /*.limit=*/ limit
838 };
839
840 const int64_t nrows = ggml_nrows(op->src[0]);
841
842 const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
843
844 ggml_metal_encoder_set_pipeline(enc, pipeline);
845 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
846 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
847 if (op->src[1]) {
848 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
849 } else {
850 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 2);
851 }
852 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
853
854 ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
855
856 return 1;
857}
858
859int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
860 ggml_tensor * op = ctx->node(idx);
861
862 ggml_metal_library_t lib = ctx->lib;
863 ggml_metal_encoder_t enc = ctx->enc;
864
865 const uint64_t n = (uint64_t) ggml_nelements(op->src[0]);
866
867 ggml_metal_kargs_sum args = {
868 /*.np =*/ n,
869 };
870
871 auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
872
873 int nth = 32; // SIMD width
874
875 while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
876 nth *= 2;
877 }
878
879 nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
880 nth = std::min(nth, (int) n);
881
882 const int nsg = (nth + 31) / 32;
883
884 ggml_metal_encoder_set_pipeline(enc, pipeline);
885 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
886 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
887 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
888
889 ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0);
890
891 ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
892
893 return 1;
894}
895
896int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
897 ggml_tensor * op = ctx->node(idx);
898
899 ggml_metal_library_t lib = ctx->lib;
900 ggml_metal_encoder_t enc = ctx->enc;
901
902 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
903 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
904 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
905 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
906
907 ggml_metal_kargs_sum_rows args = {
908 /*.ne00 =*/ ne00,
909 /*.ne01 =*/ ne01,
910 /*.ne02 =*/ ne02,
911 /*.ne03 =*/ ne03,
912 /*.nb00 =*/ nb00,
913 /*.nb01 =*/ nb01,
914 /*.nb02 =*/ nb02,
915 /*.nb03 =*/ nb03,
916 /*.ne0 =*/ ne0,
917 /*.ne1 =*/ ne1,
918 /*.ne2 =*/ ne2,
919 /*.ne3 =*/ ne3,
920 /*.nb0 =*/ nb0,
921 /*.nb1 =*/ nb1,
922 /*.nb2 =*/ nb2,
923 /*.nb3 =*/ nb3,
924 };
925
926 auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
927
928 int nth = 32; // SIMD width
929
930 while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
931 nth *= 2;
932 }
933
934 nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
935 nth = std::min(nth, ne00);
936
937 const size_t smem = pipeline.smem;
938
939 ggml_metal_encoder_set_pipeline(enc, pipeline);
940 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
941 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
942 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
943
944 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
945
946 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
947
948 return 1;
949}
950
951int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
952 ggml_tensor * op = ctx->node(idx);
953
954 ggml_metal_library_t lib = ctx->lib;
955 ggml_metal_encoder_t enc = ctx->enc;
956
957 GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
958
959 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
960 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
961 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
962 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
963
964 auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
965
966 int nth = 1;
967 while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
968 nth *= 2;
969 }
970
971 GGML_ASSERT(ne00 <= nth*nth);
972
973 const int64_t net0 = (ne00 + nth - 1) / nth;
974 const int64_t net1 = ne01;
975 const int64_t net2 = ne02;
976 const int64_t net3 = ne03;
977
978 const uint64_t nbt0 = sizeof(float);
979 const uint64_t nbt1 = net0*nbt0;
980 const uint64_t nbt2 = net1*nbt1;
981 const uint64_t nbt3 = net2*nbt2;
982
983 const size_t smem = GGML_PAD(32*sizeof(float), 16);
984
985 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
986 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
987
988 ggml_metal_buffer_id bid_tmp = bid_dst;
989 bid_tmp.offs += ggml_nbytes(op);
990
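    // pass 1 (sketch of the multi-block scan): each threadgroup computes an inclusive scan over a
    // block of up to nth elements into dst; when a row spans multiple blocks (ne00 > nth), the
    // per-block totals are also written to the scratch region bid_tmp, carved out right past the
    // dst tensor's own bytes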
991 {
992 ggml_metal_kargs_cumsum_blk args = {
993 /*.ne00 =*/ ne00,
994 /*.ne01 =*/ ne01,
995 /*.ne02 =*/ ne02,
996 /*.ne03 =*/ ne03,
997 /*.nb00 =*/ nb00,
998 /*.nb01 =*/ nb01,
999 /*.nb02 =*/ nb02,
1000 /*.nb03 =*/ nb03,
1001 /*.net0 =*/ net0,
1002 /*.net1 =*/ net1,
1003 /*.net2 =*/ net2,
1004 /*.net3 =*/ net3,
1005 /*.nbt0 =*/ nbt0,
1006 /*.nbt1 =*/ nbt1,
1007 /*.nbt2 =*/ nbt2,
1008 /*.nbt3 =*/ nbt3,
1009 /*.outb =*/ ne00 > nth,
1010 };
1011
1012 ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
1013 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1014 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
1015 ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
1016 ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
1017
1018 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1019
1020 ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
1021 }
1022
1023 if (ne00 > nth) {
1024 ggml_metal_op_concurrency_reset(ctx);
1025
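        // pass 2: scan the per-block totals in bid_tmp in-place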
1026 {
1027 ggml_metal_kargs_cumsum_blk args = {
1028 /*.ne00 =*/ net0,
1029 /*.ne01 =*/ net1,
1030 /*.ne02 =*/ net2,
1031 /*.ne03 =*/ net3,
1032 /*.nb00 =*/ nbt0,
1033 /*.nb01 =*/ nbt1,
1034 /*.nb02 =*/ nbt2,
1035 /*.nb03 =*/ nbt3,
1036 /*.net0 =*/ net0,
1037 /*.net1 =*/ net1,
1038 /*.net2 =*/ net2,
1039 /*.net3 =*/ net3,
1040 /*.nbt0 =*/ nbt0,
1041 /*.nbt1 =*/ nbt1,
1042 /*.nbt2 =*/ nbt2,
1043 /*.nbt3 =*/ nbt3,
1044 /*.outb =*/ false,
1045 };
1046
1047 ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
1048 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1049 ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
1050 ggml_metal_encoder_set_buffer (enc, bid_tmp, 2);
1051 ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
1052
1053 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1054
1055 ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
1056 }
1057
1058 ggml_metal_op_concurrency_reset(ctx);
1059
1060 {
1061 auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
1062
1063 ggml_metal_kargs_cumsum_add args = {
1064 /*.ne00 =*/ ne00,
1065 /*.ne01 =*/ ne01,
1066 /*.ne02 =*/ ne02,
1067 /*.ne03 =*/ ne03,
1068 /*.nb00 =*/ nb00,
1069 /*.nb01 =*/ nb01,
1070 /*.nb02 =*/ nb02,
1071 /*.nb03 =*/ nb03,
1072 /*.net0 =*/ net0,
1073 /*.net1 =*/ net1,
1074 /*.net2 =*/ net2,
1075 /*.net3 =*/ net3,
1076 /*.nbt0 =*/ nbt0,
1077 /*.nbt1 =*/ nbt1,
1078 /*.nbt2 =*/ nbt2,
1079 /*.nbt3 =*/ nbt3,
1080 };
1081
1082 ggml_metal_encoder_set_pipeline(enc, pipeline_add);
1083 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1084 ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
1085 ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
1086
1087 ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
1088 }
1089 }
1090
1091 return 1;
1092}
1093
1094int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
1095 ggml_tensor * op = ctx->node(idx);
1096
1097 ggml_metal_library_t lib = ctx->lib;
1098 ggml_metal_encoder_t enc = ctx->enc;
1099
1100 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1101 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1102 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1103 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1104 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1105 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1106
1107 auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
1108
1109 ggml_metal_kargs_get_rows args = {
1110 /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
1111 /*.ne00 =*/ ne00,
1112 /*.nb01 =*/ nb01,
1113 /*.nb02 =*/ nb02,
1114 /*.nb03 =*/ nb03,
1115 /*.ne10 =*/ ne10,
1116 /*.nb10 =*/ nb10,
1117 /*.nb11 =*/ nb11,
1118 /*.nb12 =*/ nb12,
1119 /*.nb1 =*/ nb1,
1120 /*.nb2 =*/ nb2,
1121 /*.nb3 =*/ nb3,
1122 };
1123
1124 const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1125
1126 const int nw0 = (args.ne00t + nth - 1)/nth;
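    // rows wider than one threadgroup are split across nw0 groups along dim 0.
    // e.g. (illustrative): ne00t = 8192 with a 1024-thread limit gives nth = 1024 and
    // nw0 = 8 threadgroups per src1 index. for quantized src0, ne00t = ne00/16 suggests
    // each thread dequantizes a 16-element chunk at a time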
1127
1128 ggml_metal_encoder_set_pipeline(enc, pipeline);
1129 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1130 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1131 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1132 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1133
1134 ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
1135
1136 return 1;
1137}
1138
1139int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
1140 ggml_tensor * op = ctx->node(idx);
1141
1142 ggml_metal_library_t lib = ctx->lib;
1143 ggml_metal_encoder_t enc = ctx->enc;
1144
1145 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1146 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1147 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1148 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1149 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1150 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1151
1152 auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
1153
1154 const int32_t nk0 = ne0/ggml_blck_size(op->type);
1155
1156 int nth = 32; // SIMD width
1157
1158 while (nth < nk0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1159 nth *= 2;
1160 }
1161
1162 int nrptg = 1;
1163 if (nth > nk0) {
1164 nrptg = (nth + nk0 - 1)/nk0;
1165 nth = nk0;
1166
1167 if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1168 nrptg--;
1169 }
1170 }
1171
1172 nth = std::min(nth, nk0);
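    // when a row is narrower than one simdgroup, pack nrptg rows per threadgroup.
    // e.g. (illustrative): nk0 = 24 -> nth starts at 32 > nk0, so nrptg = ceil(32/24) = 2
    // and nth = 24, i.e. each 24x2 threadgroup writes two rows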
1173
1174 ggml_metal_kargs_set_rows args = {
1175 /*.nk0 =*/ nk0,
1176 /*.ne01 =*/ ne01,
1177 /*.nb01 =*/ nb01,
1178 /*.nb02 =*/ nb02,
1179 /*.nb03 =*/ nb03,
1180 /*.ne11 =*/ ne11,
1181 /*.ne12 =*/ ne12,
1182 /*.nb10 =*/ nb10,
1183 /*.nb11 =*/ nb11,
1184 /*.nb12 =*/ nb12,
1185 /*.nb1 =*/ nb1,
1186 /*.nb2 =*/ nb2,
1187 /*.nb3 =*/ nb3,
1188 };
1189
1190 ggml_metal_encoder_set_pipeline(enc, pipeline);
1191 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1192 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1193 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1194 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1195
1196 ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
1197
1198 return 1;
1199}
1200
1201int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) {
1202 ggml_tensor * op = ctx->node(idx);
1203
1204 ggml_metal_library_t lib = ctx->lib;
1205 ggml_metal_encoder_t enc = ctx->enc;
1206
1207 GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
1208 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1209 GGML_TENSOR_LOCALS(int32_t, ne, op, ne);
1210 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1211
1212 ggml_metal_kargs_diag args = {
1213 /*.ne00 =*/ne00,
1214 /*.ne01 =*/ne01,
1215 /*.ne02 =*/ne02,
1216 /*.ne03 =*/ne03,
1217 /*.nb00 =*/nb00,
1218 /*.nb01 =*/nb01,
1219 /*.nb02 =*/nb02,
1220 /*.nb03 =*/nb03,
1221 /*.ne0 =*/ne0,
1222 /*.ne1 =*/ne1,
1223 /*.ne2 =*/ne2,
1224 /*.ne3 =*/ne3,
1225 /*.nb0 =*/nb0,
1226 /*.nb1 =*/nb1,
1227 /*.nb2 =*/nb2,
1228 /*.nb3 =*/nb3,
1229 };
1230
1231 auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op);
1232
1233 ggml_metal_encoder_set_pipeline(enc, pipeline);
1234 ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1235 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1236 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2);
1237
1238 ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1);
1239
1240 return 1;
1241}
1242
1243int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
1244 ggml_tensor * op = ctx->node(idx);
1245
1246 ggml_metal_library_t lib = ctx->lib;
1247 ggml_metal_encoder_t enc = ctx->enc;
1248
1249 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1250 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1251 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1252 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1253 GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1254 GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1255 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1256 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1257
1258 float scale;
1259 float max_bias;
1260
1261 memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale));
1262 memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias));
1263
1264 const uint32_t n_head = op->src[0]->ne[2];
1265 const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
1266
1267 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1268 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
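    // m0/m1 are the ALiBi slope bases: head h (0-based) uses slope m0^(h + 1) for
    // h < n_head_log2 and m1^(2*(h - n_head_log2) + 1) otherwise; the kernel derives the
    // per-head slope from these two bases and n_head_log2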
1269
1270 // softmax
1271
1272 ggml_metal_kargs_soft_max args = {
1273 /*.ne00 =*/ ne00,
1274 /*.ne01 =*/ ne01,
1275 /*.ne02 =*/ ne02,
1276 /*.nb01 =*/ nb01,
1277 /*.nb02 =*/ nb02,
1278 /*.nb03 =*/ nb03,
1279 /*.ne11 =*/ ne11,
1280 /*.ne12 =*/ ne12,
1281 /*.ne13 =*/ ne13,
1282 /*.nb11 =*/ nb11,
1283 /*.nb12 =*/ nb12,
1284 /*.nb13 =*/ nb13,
1285 /*.nb1 =*/ nb1,
1286 /*.nb2 =*/ nb2,
1287 /*.nb3 =*/ nb3,
1288 /*.scale =*/ scale,
1289 /*.max_bias =*/ max_bias,
1290 /*.m0 =*/ m0,
1291 /*.m1 =*/ m1,
1292 /*.n_head_log2 =*/ n_head_log2,
1293 };
1294
1295 auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
1296
1297 int nth = 32; // SIMD width
1298
1299 if (ne00%4 == 0) {
1300 while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) {
1301 nth *= 2;
1302 }
1303 } else {
1304 while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
1305 nth *= 2;
1306 }
1307 }
1308
1309 const size_t smem = pipeline.smem;
1310
1311 ggml_metal_encoder_set_pipeline(enc, pipeline);
1312 ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1313 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1314 if (op->src[1]) {
1315 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1316 } else {
1317 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 2);
1318 }
1319 if (op->src[2]) {
1320 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[2]), 3);
1321 } else {
1322 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
1323 }
1324 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 4);
1325
1326 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1327
1328 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
1329
1330 return 1;
1331}
1332
1333int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
1334 ggml_tensor * op = ctx->node(idx);
1335
1336 ggml_metal_library_t lib = ctx->lib;
1337 ggml_metal_encoder_t enc = ctx->enc;
1338
1339 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1340 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1341 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1342 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1343 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1344 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1345
1346 ggml_metal_kargs_ssm_conv args = {
1347 /*.ne00 =*/ ne00,
1348 /*.ne01 =*/ ne01,
1349 /*.ne02 =*/ ne02,
1350 /*.nb00 =*/ nb00,
1351 /*.nb01 =*/ nb01,
1352 /*.nb02 =*/ nb02,
1353 /*.ne10 =*/ ne10,
1354 /*.ne11 =*/ ne11,
1355 /*.nb10 =*/ nb10,
1356 /*.nb11 =*/ nb11,
1357 /*.ne0 =*/ ne0,
1358 /*.ne1 =*/ ne1,
1359 /*.ne2 =*/ ne2,
1360 /*.nb0 =*/ nb0,
1361 /*.nb1 =*/ nb1,
1362 /*.nb2 =*/ nb2,
1363 };
1364
1365 // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead
1366 const bool use_batched = (ne1 > 1);
1367
1368 if (use_batched) {
1369 // Pick a power-of-2 BATCH_SIZE from ne1, capped at 256 (ne1 <= 4 maps to BATCH_SIZE = 2)
1370 int BATCH_SIZE;
1371 if (ne1 > 128) BATCH_SIZE = 256;
1372 else if (ne1 > 64 ) BATCH_SIZE = 128;
1373 else if (ne1 > 32 ) BATCH_SIZE = 64;
1374 else if (ne1 > 16 ) BATCH_SIZE = 32;
1375 else if (ne1 > 8 ) BATCH_SIZE = 16;
1376 else if (ne1 > 4 ) BATCH_SIZE = 8;
1377 else BATCH_SIZE = 2;
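        // e.g. (illustrative): ne1 = 100 tokens -> BATCH_SIZE = 128 and a single token
        // batch below; ne1 = 3 -> BATCH_SIZE = 2 and two token batches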
1378
1379 auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE);
1380
1381 ggml_metal_encoder_set_pipeline(enc, pipeline);
1382 ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1383 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1384 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1385 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
1386
1387 // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences
1388 // Each threadgroup has BATCH_SIZE threads, each handling one token
1389 const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE;
1390 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1);
1391 } else {
1392 auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
1393
1394 ggml_metal_encoder_set_pipeline(enc, pipeline);
1395 ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
1396 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1397 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1398 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
1399
1400 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
1401 }
1402
1403 return 1;
1404}
1405
1406int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
1407 ggml_tensor * op = ctx->node(idx);
1408
1409 ggml_metal_library_t lib = ctx->lib;
1410 ggml_metal_encoder_t enc = ctx->enc;
1411
1412 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1413 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1414 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1415 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1416 GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
1417 GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
1418 GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
1419 GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
1420 GGML_TENSOR_LOCALS( int32_t, ne4, op->src[4], ne);
1421 GGML_TENSOR_LOCALS(uint64_t, nb4, op->src[4], nb);
1422 GGML_TENSOR_LOCALS( int32_t, ne5, op->src[5], ne);
1423 GGML_TENSOR_LOCALS(uint64_t, nb5, op->src[5], nb);
1424 GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
1425 GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
1426 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1427 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1428
1429 const ggml_tensor * src3 = op->src[3];
1430 const ggml_tensor * src4 = op->src[4];
1431 const ggml_tensor * src5 = op->src[5];
1432 const ggml_tensor * src6 = op->src[6];
1433
1434 GGML_ASSERT(src3);
1435 GGML_ASSERT(src4);
1436 GGML_ASSERT(src5);
1437 GGML_ASSERT(src6);
1438
1439 const int64_t d_state = ne00;
1440 const int64_t d_inner = ne01;
1441 const int64_t n_head = ne02;
1442 const int64_t n_group = ne41;
1443 const int64_t n_seq_tokens = ne12;
1444 const int64_t n_seqs = ne13;
1445
1446 ggml_metal_kargs_ssm_scan args = {
1447 /*.d_state =*/ d_state,
1448 /*.d_inner =*/ d_inner,
1449 /*.n_head =*/ n_head,
1450 /*.n_group =*/ n_group,
1451 /*.n_seq_tokens =*/ n_seq_tokens,
1452 /*.n_seqs =*/ n_seqs,
1453 /*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float),
1454 /*.nb00 =*/ nb00,
1455 /*.nb01 =*/ nb01,
1456 /*.nb02 =*/ nb02,
1457 /*.nb03 =*/ nb03,
1458 /*.nb10 =*/ nb10,
1459 /*.nb11 =*/ nb11,
1460 /*.nb12 =*/ nb12,
1461 /*.ns12 =*/ nb12/nb10,
1462 /*.nb13 =*/ nb13,
1463 /*.nb20 =*/ nb20,
1464 /*.nb21 =*/ nb21,
1465 /*.ns21 =*/ nb21/nb20,
1466 /*.nb22 =*/ nb22,
1467 /*.ne30 =*/ ne30,
1468 /*.nb31 =*/ nb31,
1469 /*.nb41 =*/ nb41,
1470 /*.nb42 =*/ nb42,
1471 /*.ns42 =*/ nb42/nb40,
1472 /*.nb43 =*/ nb43,
1473 /*.nb51 =*/ nb51,
1474 /*.nb52 =*/ nb52,
1475 /*.ns52 =*/ nb52/nb50,
1476 /*.nb53 =*/ nb53,
1477 /*.nb0 =*/ nb0,
1478 };
1479
1480 auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
1481
1482 GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1483
1484 const size_t smem = pipeline.smem;
1485
1486 ggml_metal_encoder_set_pipeline(enc, pipeline);
1487 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1488 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1489 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1490 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
1491 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4);
1492 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5);
1493 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 6);
1494 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7);
1495 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8);
1496
1497 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1498
1499 ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
1500
1501 return 1;
1502}
1503
1504int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
1505 ggml_tensor * op = ctx->node(idx);
1506
1507 ggml_metal_library_t lib = ctx->lib;
1508 ggml_metal_encoder_t enc = ctx->enc;
1509
1510 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1511 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1512 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1513 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1514
1515 const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
1516 const int64_t T = op->src[0]->ne[2];
1517 const int64_t C = op->ne[0];
1518 const int64_t H = op->src[0]->ne[1];
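    // B = sequences, T = tokens, C = channels, H = heads; the dispatch below launches one
    // threadgroup per (sequence, head) pair - B*H in total - with C/H threads each,
    // i.e. one thread per channel within a head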
1519
1520 auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
1521
1522 int ida = 0;
1523
1524 ggml_metal_encoder_set_pipeline(enc, pipeline);
1525 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
1526 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
1527 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
1528 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
1529 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
1530 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++);
1531 if (op->op == GGML_OP_RWKV_WKV7) {
1532 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++);
1533 }
1534 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++);
1535 ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++);
1536 ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++);
1537 ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++);
1538 ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++);
1539
1540 ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1);
1541
1542 return 1;
1543}
1544
1545int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
1546 ggml_tensor * op = ctx->node(idx);
1547
1548 ggml_metal_library_t lib = ctx->lib;
1549 ggml_metal_encoder_t enc = ctx->enc;
1550
1551 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1552 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1553 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1554 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1555 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1556 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1557
1558 ggml_metal_kargs_solve_tri args = {
1559 /*.ne00 =*/ ne00,
1560 /*.ne01 =*/ ne01,
1561 /*.ne02 =*/ ne02,
1562 /*.ne03 =*/ ne03,
1563 /*.nb00 =*/ nb00,
1564 /*.nb01 =*/ nb01,
1565 /*.nb02 =*/ nb02,
1566 /*.nb03 =*/ nb03,
1567 /*.ne10 =*/ ne10,
1568 /*.ne11 =*/ ne11,
1569 /*.ne12 =*/ ne12,
1570 /*.ne13 =*/ ne13,
1571 /*.nb10 =*/ nb10,
1572 /*.nb11 =*/ nb11,
1573 /*.nb12 =*/ nb12,
1574 /*.nb13 =*/ nb13,
1575 /*.ne0 =*/ ne0,
1576 /*.ne1 =*/ ne1,
1577 /*.ne2 =*/ ne2,
1578 /*.ne3 =*/ ne3,
1579 /*.nb0 =*/ nb0,
1580 /*.nb1 =*/ nb1,
1581 /*.nb2 =*/ nb2,
1582 /*.nb3 =*/ nb3,
1583 };
1584
1585 auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
1586
1587 ggml_metal_encoder_set_pipeline(enc, pipeline);
1588 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1589 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1590 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1591 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1592
1593 const int nsg = pipeline.nsg;
1594
1595 ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0);
1596
1597 ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1);
1598
1599 return 1;
1600}
1601
1602int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
1603 ggml_tensor * op = ctx->node(idx);
1604
1605 ggml_metal_library_t lib = ctx->lib;
1606 ggml_metal_encoder_t enc = ctx->enc;
1607
1608 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1609 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1610 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1611 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1612
1613 auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
1614
1615 GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
1616
1617 int64_t nk0 = ne00;
1618 if (ggml_is_quantized(op->src[0]->type)) {
1619 nk0 = ne00/16;
1620 } else if (ggml_is_quantized(op->type)) {
1621 nk0 = ne00/ggml_blck_size(op->type);
1622 }
1623
1624 int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1625
1626 // when rows are small, we can batch them together in a single threadgroup
1627 int nrptg = 1;
1628
1629 // TODO: relax this constraint in the future
1630 if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
1631 if (nth > nk0) {
1632 nrptg = (nth + nk0 - 1)/nk0;
1633 nth = nk0;
1634
1635 if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1636 nrptg--;
1637 }
1638 }
1639 }
1640
1641 nth = std::min<int>(nth, nk0);
1642
1643 ggml_metal_kargs_cpy args = {
1644 /*.nk0 =*/ nk0,
1645 /*.ne00 =*/ ne00,
1646 /*.ne01 =*/ ne01,
1647 /*.ne02 =*/ ne02,
1648 /*.ne03 =*/ ne03,
1649 /*.nb00 =*/ nb00,
1650 /*.nb01 =*/ nb01,
1651 /*.nb02 =*/ nb02,
1652 /*.nb03 =*/ nb03,
1653 /*.ne0 =*/ ne0,
1654 /*.ne1 =*/ ne1,
1655 /*.ne2 =*/ ne2,
1656 /*.ne3 =*/ ne3,
1657 /*.nb0 =*/ nb0,
1658 /*.nb1 =*/ nb1,
1659 /*.nb2 =*/ nb2,
1660 /*.nb3 =*/ nb3,
1661 };
1662
1663 const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
1664
1665 ggml_metal_encoder_set_pipeline(enc, pipeline);
1666 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1667 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1668 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
1669
1670 ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
1671
1672 return 1;
1673}
1674
1675int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
1676 ggml_tensor * op = ctx->node(idx);
1677
1678 ggml_metal_library_t lib = ctx->lib;
1679 ggml_metal_encoder_t enc = ctx->enc;
1680
1681 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1682 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1683 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1684 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1685
1686 const int32_t * opts = op->op_params;
1687 ggml_op_pool op_pool = (ggml_op_pool) opts[0];
1688
1689 const int32_t k0 = opts[1];
1690 const int32_t s0 = opts[2];
1691 const int32_t p0 = opts[3];
1692
1693 const int64_t IW = op->src[0]->ne[0];
1694 const int64_t OW = op->ne[0];
1695
1696 const int64_t np = ggml_nelements(op);
1697
1698 ggml_metal_kargs_pool_1d args_pool_1d = {
1699 /* .k0 = */ k0,
1700 /* .s0 = */ s0,
1701 /* .p0 = */ p0,
1702 /* .IW = */ IW,
1703 /* .OW = */ OW,
1704 /* .np = */ np
1705 };
1706
1707 auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
1708
1709 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
1710 const int ntg = (np + nth - 1) / nth;
1711
1712 ggml_metal_encoder_set_pipeline(enc, pipeline);
1713 ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0);
1714 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1715 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
1716
1717 ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
1718
1719 return 1;
1720}
1721
1722
1723int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
1724 ggml_tensor * op = ctx->node(idx);
1725
1726 ggml_metal_library_t lib = ctx->lib;
1727 ggml_metal_encoder_t enc = ctx->enc;
1728
1729 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1730 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1731 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1732 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1733
1734 const int32_t * opts = op->op_params;
1735 ggml_op_pool op_pool = (ggml_op_pool) opts[0];
1736
1737 const int32_t k0 = opts[1];
1738 const int32_t k1 = opts[2];
1739 const int32_t s0 = opts[3];
1740 const int32_t s1 = opts[4];
1741 const int32_t p0 = opts[5];
1742 const int32_t p1 = opts[6];
1743
1744 const int64_t IH = op->src[0]->ne[1];
1745 const int64_t IW = op->src[0]->ne[0];
1746
1747 const int64_t N = op->ne[3];
1748 const int64_t OC = op->ne[2];
1749 const int64_t OH = op->ne[1];
1750 const int64_t OW = op->ne[0];
1751
1752 const int64_t np = N * OC * OH * OW;
1753
1754 ggml_metal_kargs_pool_2d args_pool_2d = {
1755 /* .k0 = */ k0,
1756 /* .k1 = */ k1,
1757 /* .s0 = */ s0,
1758 /* .s1 = */ s1,
1759 /* .p0 = */ p0,
1760 /* .p1 = */ p1,
1761 /* .IH = */ IH,
1762 /* .IW = */ IW,
1763 /* .OH = */ OH,
1764 /* .OW = */ OW,
1765 /* .np = */ np
1766 };
1767
1768 auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
1769
1770 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
1771 const int ntg = (np + nth - 1) / nth;
1772
1773 ggml_metal_encoder_set_pipeline(enc, pipeline);
1774 ggml_metal_encoder_set_bytes (enc, &args_pool_2d, sizeof(args_pool_2d), 0);
1775 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1776 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
1777
1778 ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
1779
1780 return 1;
1781}
1782
1783int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
1784 ggml_tensor * op = ctx->node(idx);
1785
1786 ggml_metal_library_t lib = ctx->lib;
1787 ggml_metal_encoder_t enc = ctx->enc;
1788
1789 const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);
1790
1791 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
1792 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
1793 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
1794 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
1795 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
1796 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
1797
1798 GGML_ASSERT(ne00 == ne10);
1799
1800 GGML_ASSERT(ne12 % ne02 == 0);
1801 GGML_ASSERT(ne13 % ne03 == 0);
1802
1803 const int16_t r2 = ne12/ne02;
1804 const int16_t r3 = ne13/ne03;
1805
1806 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1807 // to the matrix-vector kernel
1808 const int ne11_mm_min = 8;
1809
1810 // first try to use small-batch mat-mv kernels
1811 // these should be efficient for BS [2, ~8]
1812 if (op->src[1]->type == GGML_TYPE_F32 && (ne00%128 == 0) &&
1813 (
1814 (
1815 (
1816 op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function
1817 op->src[0]->type == GGML_TYPE_F16 ||
1818 op->src[0]->type == GGML_TYPE_Q4_0 ||
1819 op->src[0]->type == GGML_TYPE_Q4_1 ||
1820 op->src[0]->type == GGML_TYPE_Q5_0 ||
1821 op->src[0]->type == GGML_TYPE_Q5_1 ||
1822 op->src[0]->type == GGML_TYPE_Q8_0 ||
1823 op->src[0]->type == GGML_TYPE_MXFP4 ||
1824 op->src[0]->type == GGML_TYPE_IQ4_NL ||
1825 false) && (ne11 >= 2 && ne11 <= 8)
1826 ) ||
1827 (
1828 (
1829 op->src[0]->type == GGML_TYPE_Q4_K ||
1830 op->src[0]->type == GGML_TYPE_Q5_K ||
1831 op->src[0]->type == GGML_TYPE_Q6_K ||
1832 false) && (ne11 >= 4 && ne11 <= 8)
1833 )
1834 )
1835 ) {
1836 // TODO: determine the optimal parameters based on grid utilization
1837 // I still don't know why we should not always use the maximum available threads:
1838 //
1839 // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
1840 //
1841 // my current hypothesis is that the work grid is not evenly divisible for different nsg
1842 // values and there can be some tail effects when nsg is high. need to confirm this
1843 //
1844 const int nsg = 2; // num simdgroups per threadgroup
1845
1846 // num threads along row per simdgroup
1847 int16_t nxpsg = 0;
1848 if (ne00 % 256 == 0 && ne11 < 3) {
1849 nxpsg = 16;
1850 } else if (ne00 % 128 == 0) {
1851 nxpsg = 8;
1852 } else {
1853 nxpsg = 4;
1854 }
1855
1856 const int16_t nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
1857 const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup
1858 int16_t r1ptg = 4; // num src1 rows per threadgroup
1859
1860 // note: not sure how optimal these values are across different hardware. there might be something cleverer
1861 switch (ne11) {
1862 case 2:
1863 r1ptg = 2; break;
1864 case 3:
1865 case 6:
1866 r1ptg = 3; break;
1867 case 4:
1868 case 7:
1869 case 8:
1870 r1ptg = 4; break;
1871 case 5:
1872 r1ptg = 5; break;
1873 default:
1874 GGML_ABORT("unsupported ne11");
1875 }
1876
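        // e.g. (illustrative): ne00 % 128 == 0 and ne11 = 6 give nxpsg = 8, nypsg = 4,
        // r0ptg = 8 and r1ptg = 3, so the dispatch below launches
        // ceil(ne01/8) x 2 x ne12*ne13 threadgroups of nsg = 2 simdgroups each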
1877 auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
1878
1879 ggml_metal_kargs_mul_mv_ext args = {
1880 /*.ne00 =*/ ne00,
1881 /*.ne01 =*/ ne01,
1882 /*.ne02 =*/ ne02,
1883 /*.nb00 =*/ nb00,
1884 /*.nb01 =*/ nb01,
1885 /*.nb02 =*/ nb02,
1886 /*.nb03 =*/ nb03,
1887 /*.ne10 =*/ ne10,
1888 /*.ne11 =*/ ne11,
1889 /*.ne12 =*/ ne12,
1890 /*.nb10 =*/ nb10,
1891 /*.nb11 =*/ nb11,
1892 /*.nb12 =*/ nb12,
1893 /*.nb13 =*/ nb13,
1894 /*.ne0 =*/ ne0,
1895 /*.ne1 =*/ ne1,
1896 /*.r2 =*/ r2,
1897 /*.r3 =*/ r3,
1898 };
1899
1900 ggml_metal_encoder_set_pipeline(enc, pipeline);
1901 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1902 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1903 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1904 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1905
1906 ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + r0ptg - 1)/r0ptg), ((ne11 + r1ptg - 1)/r1ptg), ne12*ne13, 32, nsg, 1);
1907 } else if (
1908 !ggml_is_transposed(op->src[0]) &&
1909 !ggml_is_transposed(op->src[1]) &&
1910 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1911 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
1912 props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) {
1913 //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
1914
1915 // some Metal matrix data types require aligned pointers
1916 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
1917 //switch (op->src[0]->type) {
1918 // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
1919 // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
1920 // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
1921 // default: break;
1922 //}
1923
1924 auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
1925
1926 ggml_metal_kargs_mul_mm args = {
1927 /*.ne00 =*/ ne00,
1928 /*.ne02 =*/ ne02,
1929 /*.nb01 =*/ nb01,
1930 /*.nb02 =*/ nb02,
1931 /*.nb03 =*/ nb03,
1932 /*.ne12 =*/ ne12,
1933 /*.nb10 =*/ nb10,
1934 /*.nb11 =*/ nb11,
1935 /*.nb12 =*/ nb12,
1936 /*.nb13 =*/ nb13,
1937 /*.ne0 =*/ ne0,
1938 /*.ne1 =*/ ne1,
1939 /*.r2 =*/ r2,
1940 /*.r3 =*/ r3,
1941 };
1942
1943 ggml_metal_encoder_set_pipeline(enc, pipeline);
1944 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1945 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1946 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1947 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1948
1949 const size_t smem = pipeline.smem;
1950
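        // per the dispatch below, each 128-thread threadgroup (4 simdgroups) produces a
        // 64 (src0 rows) x 32 (src1 rows) tile of dst, staging its operand blocks in
        // threadgroup memory (smem)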
1951 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1952 ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
1953 } else {
1954 auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
1955
1956 const int nr0 = pipeline.nr0;
1957 const int nr1 = pipeline.nr1;
1958 const int nsg = pipeline.nsg;
1959
1960 const size_t smem = pipeline.smem;
1961
1962 ggml_metal_kargs_mul_mv args = {
1963 /*.ne00 =*/ ne00,
1964 /*.ne01 =*/ ne01,
1965 /*.ne02 =*/ ne02,
1966 /*.nb00 =*/ nb00,
1967 /*.nb01 =*/ nb01,
1968 /*.nb02 =*/ nb02,
1969 /*.nb03 =*/ nb03,
1970 /*.ne10 =*/ ne10,
1971 /*.ne11 =*/ ne11,
1972 /*.ne12 =*/ ne12,
1973 /*.nb10 =*/ nb10,
1974 /*.nb11 =*/ nb11,
1975 /*.nb12 =*/ nb12,
1976 /*.nb13 =*/ nb13,
1977 /*.ne0 =*/ ne0,
1978 /*.ne1 =*/ ne1,
1979 /*.nr0 =*/ nr0,
1980 /*.r2 =*/ r2,
1981 /*.r3 =*/ r3,
1982 };
1983
1984 ggml_metal_encoder_set_pipeline(enc, pipeline);
1985 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
1986 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
1987 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
1988 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
1989
1990 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
1991
1992 if (op->src[0]->type == GGML_TYPE_F32 ||
1993 op->src[0]->type == GGML_TYPE_F16 ||
1994 op->src[0]->type == GGML_TYPE_BF16 ||
1995 op->src[0]->type == GGML_TYPE_Q8_0) {
1996 ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
1997 } else {
1998 ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
1999 }
2000 }
2001
2002 return 1;
2003}
2004
2005size_t ggml_metal_op_mul_mat_id_extra_tpe(const ggml_tensor * op) {
2006 assert(op->op == GGML_OP_MUL_MAT_ID);
2007
2008 const int64_t ne02 = op->src[0]->ne[2]; // n_expert
2009
2010 return ggml_type_size(GGML_TYPE_I32)*ne02;
2011}
2012
2013size_t ggml_metal_op_mul_mat_id_extra_ids(const ggml_tensor * op) {
2014 assert(op->op == GGML_OP_MUL_MAT_ID);
2015
2016 const int64_t ne02 = op->src[0]->ne[2]; // n_expert
2017 const int64_t ne21 = op->src[2]->ne[1]; // n_token
2018
2019 return ggml_type_size(GGML_TYPE_I32)*ne02*ne21;
2020}
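// sizes of the two scratch regions appended after dst (see ggml_metal_op_mul_mat_id below):
// tpe holds one i32 per expert (presumably a per-expert token count), ids one i32 per
// (expert, token) pair. e.g. (illustrative): ne02 = 8 experts, ne21 = 512 tokens ->
// 32 B for tpe and 16 KiB for ids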
2021
2022int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
2023 ggml_tensor * op = ctx->node(idx);
2024
2025 ggml_metal_library_t lib = ctx->lib;
2026 ggml_metal_encoder_t enc = ctx->enc;
2027
2028 const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);
2029
2030 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2031 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2032 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2033 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2034 GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2035 GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2036 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2037 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2038
2039 // src2 = ids
2040 GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
2041
2042 GGML_ASSERT(!ggml_is_transposed(op->src[0]));
2043 GGML_ASSERT(!ggml_is_transposed(op->src[1]));
2044
2045 GGML_ASSERT(ne03 == 1);
2046 GGML_ASSERT(ne13 == 1);
2047
2048 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2049 ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
2050 ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
2051 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
2052
2053 const uint32_t r2 = 1;
2054 const uint32_t r3 = 1;
2055
2056 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
2057 // to the matrix-vector kernel
2058 // ne20 = n_used_experts
2059 // ne21 = n_rows (batch size)
2060 const int ne21_mm_id_min = 32;
2061
2062 if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
2063 // some Metal matrix data types require aligned pointers
2064 // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
2065 //switch (op->src[0]->type) {
2066 // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
2067 // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
2068 // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
2069 // default: break;
2070 //}
2071
2072 // extra buffers for intermediate id mapping
2073 ggml_metal_buffer_id bid_tpe = bid_dst;
2074 bid_tpe.offs += ggml_nbytes(op);
2075
2076 ggml_metal_buffer_id bid_ids = bid_tpe;
2077 bid_ids.offs += ggml_metal_op_mul_mat_id_extra_tpe(op);
2078
2079 {
2080 ggml_metal_kargs_mul_mm_id_map0 args = {
2081 ne02,
2082 ne10,
2083 ne11, // n_expert_used (bcast)
2084 nb11,
2085 nb12,
2086 ne21, // n_tokens
2087 ne20, // n_expert_used
2088 nb21,
2089 };
2090
2091 auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
2092
2093 const size_t smem = pipeline.smem;
2094
2095 GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2096
2097 GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
2098
2099 ggml_metal_encoder_set_pipeline(enc, pipeline);
2100 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2101 ggml_metal_encoder_set_buffer (enc, bid_src2, 1);
2102 ggml_metal_encoder_set_buffer (enc, bid_tpe, 2);
2103 ggml_metal_encoder_set_buffer (enc, bid_ids, 3);
2104
2105 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2106
2107 ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, ne02, 1, 1);
2108 }
2109
2110 // this barrier is always needed because the next kernel has to wait for the id maps to be computed
2111 ggml_metal_op_concurrency_reset(ctx);
2112
2113 {
2114 auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
2115
2116 ggml_metal_kargs_mul_mm_id args = {
2117 /*.ne00 =*/ ne00,
2118 /*.ne02 =*/ ne02,
2119 /*.nb01 =*/ nb01,
2120 /*.nb02 =*/ nb02,
2121 /*.nb03 =*/ nb03,
2122 /*.ne11 =*/ ne11, // n_expert_used (bcast)
2123 /*.nb10 =*/ nb10,
2124 /*.nb11 =*/ nb11,
2125 /*.nb12 =*/ nb12,
2126 /*.nb13 =*/ nb13,
2127 /*.ne20 =*/ ne20, // n_expert_used
2128 /*.ne21 =*/ ne21, // n_tokens
2129 /*.ne0 =*/ ne0,
2130 /*.ne1 =*/ ne1,
2131 /*.r2 =*/ r2,
2132 /*.r3 =*/ r3,
2133 };
2134
2135 ggml_metal_encoder_set_pipeline(enc, pipeline);
2136 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2137 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2138 ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2139 ggml_metal_encoder_set_buffer (enc, bid_tpe, 3);
2140 ggml_metal_encoder_set_buffer (enc, bid_ids, 4);
2141 ggml_metal_encoder_set_buffer (enc, bid_dst, 5);
2142
2143 const size_t smem = pipeline.smem;
2144
2145 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2146
2147 ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
2148 }
2149 } else {
2150 auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
2151
2152 const int nr0 = pipeline.nr0;
2153 const int nr1 = pipeline.nr1;
2154 const int nsg = pipeline.nsg;
2155
2156 const size_t smem = pipeline.smem;
2157
2158 ggml_metal_kargs_mul_mv_id args = {
2159 /*.nei0 =*/ ne20,
2160 /*.nei1 =*/ ne21,
2161 /*.nbi1 =*/ nb21,
2162 /*.ne00 =*/ ne00,
2163 /*.ne01 =*/ ne01,
2164 /*.ne02 =*/ ne02,
2165 /*.nb00 =*/ nb00,
2166 /*.nb01 =*/ nb01,
2167 /*.nb02 =*/ nb02,
2168 /*.ne10 =*/ ne10,
2169 /*.ne11 =*/ ne11,
2170 /*.ne12 =*/ ne12,
2171 /*.ne13 =*/ ne13,
2172 /*.nb10 =*/ nb10,
2173 /*.nb11 =*/ nb11,
2174 /*.nb12 =*/ nb12,
2175 /*.ne0 =*/ ne0,
2176 /*.ne1 =*/ ne1,
2177 /*.nb1 =*/ nb1,
2178 /*.nr0 =*/ nr0,
2179 };
2180
2181 if (ggml_is_quantized(op->src[0]->type)) {
2182 GGML_ASSERT(ne00 >= nsg*nr0);
2183 }
2184
2185 ggml_metal_encoder_set_pipeline(enc, pipeline);
2186 ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
2187 ggml_metal_encoder_set_buffer(enc, bid_src0, 1);
2188 ggml_metal_encoder_set_buffer(enc, bid_src1, 2);
2189 ggml_metal_encoder_set_buffer(enc, bid_dst, 3);
2190 ggml_metal_encoder_set_buffer(enc, bid_src2, 4);
2191
2192 const int64_t _ne1 = 1;
2193 const int64_t ne123 = ne20*ne21;
2194
2195 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2196
2197 if (op->src[0]->type == GGML_TYPE_F32 ||
2198 op->src[0]->type == GGML_TYPE_F16 ||
2199 op->src[0]->type == GGML_TYPE_BF16 ||
2200 op->src[0]->type == GGML_TYPE_Q8_0) {
2201 ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
2202 } else {
2203 ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
2204 }
2205 }
2206
2207 return 1;
2208}
2209
2210int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
2211 ggml_tensor * op = ctx->node(idx);
2212
2213 ggml_metal_library_t lib = ctx->lib;
2214 ggml_metal_encoder_t enc = ctx->enc;
2215
2216 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2217 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2218 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2219 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2220 GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2221 GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2222 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2223
2224 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
2225 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
2226 GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
2227 GGML_ASSERT(op->type == GGML_TYPE_F32);
2228
2229 GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
2230
2231 ggml_metal_kargs_add_id args = {
2232 /*.ne0 =*/ ne0,
2233 /*.ne1 =*/ ne1,
2234 /*.nb01 =*/ nb01,
2235 /*.nb02 =*/ nb02,
2236 /*.nb11 =*/ nb11,
2237 /*.nb21 =*/ nb21,
2238 };
2239
2240 auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
2241
2242 ggml_metal_encoder_set_pipeline(enc, pipeline);
2243 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2244 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
2245 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
2246 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
2247 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 4);
2248
2249 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
2250
2251 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, 1, nth, 1, 1);
2252
2253 return 1;
2254}
2255
2256bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) {
2257 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
2258
2259 const int64_t ne00 = op->src[0]->ne[0]; // head size
2260 const int64_t ne01 = op->src[0]->ne[1]; // batch size
2261
2262 // use vec kernel if the batch size is small and if the head size is supported
2263 return (ne01 < 20) && (ne00 % 32 == 0);
2264}
2265
2266size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
2267 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
2268
2269 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2270 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2271 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2272 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2273 GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2274 GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2275 GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2276 GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2277
2278 size_t res = 0;
2279
2280 const bool has_mask = op->src[3] != nullptr;
2281
2282 // note: the non-vec kernel requires more extra memory, so always reserve for it
2283 GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);
2284
2285 //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
2286 if (false) {
2287 // note: always reserve the padding space to avoid graph reallocations
2288 //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
2289 const bool has_kvpad = true;
2290
2291 if (has_kvpad) {
2292 res += OP_FLASH_ATTN_EXT_VEC_NCPSG*(
2293 nb11*ne12*ne13 +
2294 nb21*ne22*ne23 +
2295 (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
2296 }
2297 } else {
2298 //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0;
2299 const bool has_kvpad = true;
2300
2301 if (has_kvpad) {
2302 res += OP_FLASH_ATTN_EXT_NCPSG*(
2303 nb11*ne12*ne13 +
2304 nb21*ne22*ne23 +
2305 (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0));
2306 }
2307 }
2308
2309 return res;
2310}
2311
2312size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) {
2313 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
2314
2315 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2316 //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2317 //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2318 //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2319 //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2320 //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2321 GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2322 GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2323
2324 size_t res = 0;
2325
2326 const bool has_mask = op->src[3] != nullptr;
2327
2328 if (!has_mask) {
2329 return res;
2330 }
2331
2332 const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op);
2333
2334 // this optimization is not useful for the vector kernels
2335 // note: always reserve the blk buffer to avoid graph reallocations
2336 //if (is_vec) {
2337 // return res;
2338 //}
2339
2340 const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG;
2341 const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG;
2342
2343 const int64_t ne1 = (ne01 + nqptg - 1)/nqptg;
2344 const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg;
2345
2346 res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32);
2347
2348 return res;
2349}
2350
2351size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
2352 assert(op->op == GGML_OP_FLASH_ATTN_EXT);
2353
2354 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2355 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2356 //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2357 //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2358 GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2359 GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2360 //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2361 //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2362
2363 size_t res = 0;
2364
2365 // note: always reserve the temp buffer to avoid graph reallocations
2366 //if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
2367 if (true) {
2368 const int64_t nwg = 32;
2369 const int64_t ne01_max = std::min(ne01, 32);
2370
2371 // temp buffer for writing the results from each workgroup
2372 // - ne20: the size of the Value head
2373 // - + 2: the S and M values for each intermediate result
2374 res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2));
2375 }
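    // e.g. (illustrative): ne01 = 1 (single-token decode), ne02 = 32 heads, ne03 = 1,
    // ne20 = 128, nwg = 32 -> 4*(1*32*1*32*(128 + 2)) = 532480 bytes (~520 KiB) of
    // scratch for the per-workgroup partial results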
2376
2377 return res;
2378}
2379
2380int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
2381 ggml_tensor * op = ctx->node(idx);
2382
2383 ggml_metal_library_t lib = ctx->lib;
2384 ggml_metal_encoder_t enc = ctx->enc;
2385
2386 const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev);
2387
2388 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2389 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2390 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2391 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2392 GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
2393 GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
2394 GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne);
2395 GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb);
2396 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2397 GGML_TENSOR_LOCALS( int32_t, nb, op, nb);
2398
2399 GGML_ASSERT(ne00 % 4 == 0);
2400
2401 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
2402 GGML_ASSERT(op->src[1]->type == op->src[2]->type);
2403
2404 //GGML_ASSERT(ggml_are_same_shape (src1, src2));
2405 GGML_ASSERT(ne11 == ne21);
2406 GGML_ASSERT(ne12 == ne22);
2407
2408 GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16);
2409 GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] &&
2410 "the Flash-Attention Metal kernel requires the mask to be at least n_queries big");
2411
2412 float scale;
2413 float max_bias;
2414 float logit_softcap;
2415
2416 memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale));
2417 memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias));
2418 memcpy(&logit_softcap, ((const int32_t *) op->op_params) + 2, sizeof(logit_softcap));
2419
2420 if (logit_softcap != 0.0f) {
2421 scale /= logit_softcap;
2422 }
2423
2424 const bool has_mask = op->src[3] != NULL;
2425 const bool has_sinks = op->src[4] != NULL;
2426 const bool has_bias = max_bias != 0.0f;
2427 const bool has_scap = logit_softcap != 0.0f;
2428
2429 const uint32_t n_head = op->src[0]->ne[2];
2430 const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
2431
2432 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
2433 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
2434
2435 GGML_ASSERT(ne01 < 65536);
2436
2437 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2438 ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
2439 ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
2440 ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
2441 ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
2442
2443 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
2444
2445 ggml_metal_buffer_id bid_pad = bid_dst;
2446 bid_pad.offs += ggml_nbytes(op);
2447
2448 ggml_metal_buffer_id bid_blk = bid_pad;
2449 bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
2450
2451 ggml_metal_buffer_id bid_tmp = bid_blk;
2452 bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op);
2453
2454 if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
2455 // half8x8 kernel
2456 const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup
2457 const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup
2458
2459 GGML_ASSERT(nqptg <= 32);
2460 GGML_ASSERT(nqptg % 8 == 0);
2461 GGML_ASSERT(ncpsg % 32 == 0);
2462
2463 bool need_sync = false;
2464
2465 const bool has_kvpad = ne11 % ncpsg != 0;
2466
2467 if (has_kvpad) {
2468 assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
2469
2470 ggml_metal_kargs_flash_attn_ext_pad args0 = {
2471 /*.ne11 =*/ne11,
2472 /*.ne_12_2 =*/ne12,
2473 /*.ne_12_3 =*/ne13,
2474 /*.nb11 =*/nb11,
2475 /*.nb12 =*/nb12,
2476 /*.nb13 =*/nb13,
2477 /*.nb21 =*/nb21,
2478 /*.nb22 =*/nb22,
2479 /*.nb23 =*/nb23,
2480 /*.ne31 =*/ne31,
2481 /*.ne32 =*/ne32,
2482 /*.ne33 =*/ne33,
2483 /*.nb31 =*/nb31,
2484 /*.nb32 =*/nb32,
2485 /*.nb33 =*/nb33,
2486 };
2487
2488 auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2489
2490 ggml_metal_encoder_set_pipeline(enc, pipeline0);
2491 ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2492 ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
2493 ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
2494 ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
2495 ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
2496
2497 assert(ne12 == ne22);
2498 assert(ne13 == ne23);
2499
2500 ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2501
2502 need_sync = true;
2503 }
2504
2505 if (has_mask) {
2506 assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0);
2507
2508 ggml_metal_kargs_flash_attn_ext_blk args0 = {
2509 /*.ne01 =*/ ne01,
2510 /*.ne30 =*/ ne30,
2511 /*.ne31 =*/ ne31,
2512 /*.ne32 =*/ ne32,
2513 /*.ne33 =*/ ne33,
2514 /*.nb31 =*/ nb31,
2515 /*.nb32 =*/ nb32,
2516 /*.nb33 =*/ nb33,
2517 };
2518
2519 auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
2520
2521 ggml_metal_encoder_set_pipeline(enc, pipeline0);
2522 ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2523 ggml_metal_encoder_set_buffer (enc, bid_src3, 1);
2524 ggml_metal_encoder_set_buffer (enc, bid_blk, 2);
2525
2526 const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg);
2527 const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg);
2528
2529 ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1);
2530
2531 need_sync = true;
2532 }
2533
2534 if (need_sync) {
2535 ggml_metal_op_concurrency_reset(ctx);
2536 }
2537
2538 const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0;
2539
2540 // 2*(2*ncpsg)
2541 // ncpsg soft_max values + ncpsg mask values
2542 //
2543 // 16*32*(nsg)
2544 // the shared memory needed for the simdgroups to load the KV cache
2545 // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
2546 //
2547#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
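        // e.g. (illustrative, assuming nqptg = 8 and ncpsg = 32): ne00 = ne20 = 128,
        // is_q = 0 -> (8*(128 + 2*128 + 2*64))*2 = 8192 bytes of threadgroup memory
        // (independent of nsg when is_q = 0)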
2548
2549 //int64_t nsgmax = 4;
2550 //
2551 //if (is_q) {
2552 // nsgmax = 2;
2553 // while (true) {
2554 // const size_t smem = FATTN_SMEM(nsgmax);
2555 // if (smem > props_dev->max_theadgroup_memory_size) {
2556 // break;
2557 // }
2558 // nsgmax *= 2;
2559 // }
2560 // nsgmax /= 2;
2561 //}
2562
2563 // simdgroups per threadgroup (a.k.a. warps)
2564 //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
2565 int32_t nsg = ne00 >= 512 ? 8 : 4;
2566
2567 const size_t smem = FATTN_SMEM(nsg);
2568
2569 ggml_metal_kargs_flash_attn_ext args = {
2570 /*.ne01 =*/ ne01,
2571 /*.ne02 =*/ ne02,
2572 /*.ne03 =*/ ne03,
2573 /*.nb01 =*/ nb01,
2574 /*.nb02 =*/ nb02,
2575 /*.nb03 =*/ nb03,
2576 /*.ne11 =*/ ne11,
2577 /*.ne_12_2 =*/ ne12,
2578 /*.ne_12_3 =*/ ne13,
2579 /*.ns10 =*/ int32_t(nb11/nb10),
2580 /*.nb11 =*/ nb11,
2581 /*.nb12 =*/ nb12,
2582 /*.nb13 =*/ nb13,
2583 /*.ns20 =*/ int32_t(nb21/nb20),
2584 /*.nb21 =*/ nb21,
2585 /*.nb22 =*/ nb22,
2586 /*.nb23 =*/ nb23,
2587 /*.ne31 =*/ ne31,
2588 /*.ne32 =*/ ne32,
2589 /*.ne33 =*/ ne33,
2590 /*.nb31 =*/ nb31,
2591 /*.nb32 =*/ nb32,
2592 /*.nb33 =*/ nb33,
2593 /*.ne1 =*/ ne1,
2594 /*.ne2 =*/ ne2,
2595 /*.ne3 =*/ ne3,
2596 /*.scale =*/ scale,
2597 /*.max_bias =*/ max_bias,
2598 /*.m0 =*/ m0,
2599 /*.m1 =*/ m1,
2600 /*.n_head_log2 =*/ n_head_log2,
2601 /*.logit_softcap =*/ logit_softcap,
2602 };
2603
2604 auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
2605
2606 ggml_metal_encoder_set_pipeline(enc, pipeline);
2607 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2608 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2609 ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2610 ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
2611 ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
2612 ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
2613 ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
2614 ggml_metal_encoder_set_buffer (enc, bid_blk, 7);
2615 ggml_metal_encoder_set_buffer (enc, bid_dst, 8);
2616
2617 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2618
2619 ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03, 32, nsg, 1);
2620#undef FATTN_SMEM
2621 } else {
2622 // half4x4 kernel
2623 const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup
2624 const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !!
2625 const int nhptg = 1; // heads per threadgroup
2626
2627 GGML_ASSERT(nqptg <= 32);
2628 GGML_ASSERT(nqptg % 1 == 0);
2629 GGML_ASSERT(ncpsg % 32 == 0);
2630
2631 bool need_sync = false;
2632
2633 const bool has_kvpad = ne11 % ncpsg != 0;
2634
2635 if (has_kvpad) {
2636 assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0);
2637
2638 ggml_metal_kargs_flash_attn_ext_pad args0 = {
2639 /*.ne11 =*/ne11,
2640 /*.ne_12_2 =*/ne12,
2641 /*.ne_12_3 =*/ne13,
2642 /*.nb11 =*/nb11,
2643 /*.nb12 =*/nb12,
2644 /*.nb13 =*/nb13,
2645 /*.nb21 =*/nb21,
2646 /*.nb22 =*/nb22,
2647 /*.nb23 =*/nb23,
2648 /*.ne31 =*/ne31,
2649 /*.ne32 =*/ne32,
2650 /*.ne33 =*/ne33,
2651 /*.nb31 =*/nb31,
2652 /*.nb32 =*/nb32,
2653 /*.nb33 =*/nb33,
2654 };
2655
2656 auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
2657
2658 ggml_metal_encoder_set_pipeline(enc, pipeline0);
2659 ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2660 ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
2661 ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
2662 ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
2663 ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
2664
2665 assert(ne12 == ne22);
2666 assert(ne13 == ne23);
2667
2668 ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1);
2669
2670 need_sync = true;
2671 }
2672
2673 if (need_sync) {
2674 ggml_metal_op_concurrency_reset(ctx);
2675 }
2676
2677 // note: for simplicity, assume the K head size is larger than or equal to the V head size
2678 GGML_ASSERT(ne10 >= ne20);
2679
2680 // ne00 + 2*ncpsg*(nsg)
2681 // for each query, we load it as f16 in shared memory (ne00)
2682 // and store the soft_max values and the mask
2683 //
2684 // ne20*(nsg)
2685 // each simdgroup has a full f32 head vector in shared mem to accumulate results
2686 //
2687#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16))
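        // e.g. (illustrative, assuming ncpsg = 32): ne00 = ne20 = 128, nsg = 2 ->
        // ((128 + 4*32 + 2*128)*2)*2 = 2048 bytes of threadgroup memory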
2688
2689 int64_t nsg = 1;
2690
2691 // workgroups
2692 // each workgroup handles nsg*ncpsg cache values
2693 int32_t nwg = 1;
2694 if (false) {
2695 // for small KV caches, we could launch a single workgroup and write the results directly to dst
2696 // however, this does not lead to significant improvement, so disabled
2697 nwg = 1;
2698 nsg = 4;
2699 } else {
2700 nwg = 32;
2701 nsg = 1;
2702 while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) {
2703 nsg *= 2;
2704 }
2705 }
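 // e.g. (again assuming ncpsg = 32): with nwg = 32 the loop keeps nsg = 1 for
 // KV lengths ne11 <= 2048, picks nsg = 2 up to ne11 = 4096, and saturates at
 // nsg = 4 beyond that, so longer KV caches get more simdgroups per workgroup
 // on top of the 32 workgroups that are reduced in the second pass below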
2706
2707 ggml_metal_kargs_flash_attn_ext_vec args = {
2708 /*.ne01 =*/ ne01,
2709 /*.ne02 =*/ ne02,
2710 /*.ne03 =*/ ne03,
2711 /*.nb01 =*/ nb01,
2712 /*.nb02 =*/ nb02,
2713 /*.nb03 =*/ nb03,
2714 /*.ne11 =*/ ne11,
2715 /*.ne_12_2 =*/ ne12,
2716 /*.ne_12_3 =*/ ne13,
2717 /*.ns10 =*/ int32_t(nb11/nb10),
2718 /*.nb11 =*/ nb11,
2719 /*.nb12 =*/ nb12,
2720 /*.nb13 =*/ nb13,
2721 /*.ns20 =*/ int32_t(nb21/nb20),
2722 /*.nb21 =*/ nb21,
2723 /*.nb22 =*/ nb22,
2724 /*.nb23 =*/ nb23,
2725 /*.ne31 =*/ ne31,
2726 /*.ne32 =*/ ne32,
2727 /*.ne33 =*/ ne33,
2728 /*.nb31 =*/ nb31,
2729 /*.nb32 =*/ nb32,
2730 /*.nb33 =*/ nb33,
2731 /*.ne1 =*/ ne1,
2732 /*.ne2 =*/ ne2,
2733 /*.ne3 =*/ ne3,
2734 /*.scale =*/ scale,
2735 /*.max_bias =*/ max_bias,
2736 /*.m0 =*/ m0,
2737 /*.m1 =*/ m1,
2738 /*.n_head_log2 =*/ n_head_log2,
2739 /*.logit_softcap =*/ logit_softcap,
2740 };
2741
2742 auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
2743
2744 GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2745
2746 ggml_metal_encoder_set_pipeline(enc, pipeline);
2747 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2748 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2749 ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2750 ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
2751 ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
2752 ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
2753
2754 const size_t smem = FATTN_SMEM(nsg);
2755
2756 //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, props_dev->max_theadgroup_memory_size, (int) nsg, (int) nsgmax);
2757 GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
2758
2759 if (nwg == 1) {
2760 assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
2761
2762 // using 1 workgroup -> write the result directly into dst
2763 ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
2764 ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
2765
2766 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2767
2768 ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
2769 } else {
2770 // sanity checks
2771 assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
2772
2773 GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
2774 GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
2775
2776 // write the results from each workgroup into a temp buffer
2777 ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
2778 ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
2779
2780 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
2781 ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1);
2782
2783 // sync the 2 kernels
2784 ggml_metal_op_concurrency_reset(ctx);
2785
2786 // reduce the results from the workgroups
2787 {
2788 const int32_t nrows = ne1*ne2*ne3;
2789
2790 ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
2791 nrows,
2792 };
2793
2794 auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
2795
2796 ggml_metal_encoder_set_pipeline(enc, pipeline0);
2797 ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
2798 ggml_metal_encoder_set_buffer (enc, bid_tmp, 1);
2799 ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
2800
2801 ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, 32*nwg, 1, 1);
2802 }
2803 }
2804#undef FATTN_SMEM
2805 }
2806
2807 return 1;
2808}
2809
2810int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
2811 ggml_tensor * op = ctx->node(idx);
2812
2813 ggml_metal_library_t lib = ctx->lib;
2814 ggml_metal_encoder_t enc = ctx->enc;
2815
2816 const bool use_fusion = ctx->use_fusion;
2817
2818 const int debug_fusion = ctx->debug_fusion;
2819
2820 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2821 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2822 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
2823 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
2824 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2825 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2826
2827 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
2828 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
2829
2830 GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
2831 GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
2832
2833 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2834 ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
2835 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
2836
2837 ggml_metal_kargs_bin args = {
2838 /*.ne00 =*/ ne00,
2839 /*.ne01 =*/ ne01,
2840 /*.ne02 =*/ ne02,
2841 /*.ne03 =*/ ne03,
2842 /*.nb00 =*/ nb00,
2843 /*.nb01 =*/ nb01,
2844 /*.nb02 =*/ nb02,
2845 /*.nb03 =*/ nb03,
2846 /*.ne10 =*/ ne10,
2847 /*.ne11 =*/ ne11,
2848 /*.ne12 =*/ ne12,
2849 /*.ne13 =*/ ne13,
2850 /*.nb10 =*/ nb10,
2851 /*.nb11 =*/ nb11,
2852 /*.nb12 =*/ nb12,
2853 /*.nb13 =*/ nb13,
2854 /*.ne0 =*/ ne0,
2855 /*.ne1 =*/ ne1,
2856 /*.ne2 =*/ ne2,
2857 /*.ne3 =*/ ne3,
2858 /*.nb0 =*/ nb0,
2859 /*.nb1 =*/ nb1,
2860 /*.nb2 =*/ nb2,
2861 /*.nb3 =*/ nb3,
2862 /*.offs =*/ 0,
2863 /*.o1 =*/ { bid_src1.offs },
2864 };
2865
2866 ggml_op fops[8];
2867
2868 int n_fuse = 1;
2869
2870 // c[0] = add(a, b[0])
2871 // c[1] = add(c[0], b[1])
2872 // c[2] = add(c[1], b[2])
2873 // ...
2874 if (use_fusion) {
2875 fops[0] = GGML_OP_ADD;
2876 fops[1] = GGML_OP_ADD;
2877 fops[2] = GGML_OP_ADD;
2878 fops[3] = GGML_OP_ADD;
2879 fops[4] = GGML_OP_ADD;
2880 fops[5] = GGML_OP_ADD;
2881 fops[6] = GGML_OP_ADD;
2882 fops[7] = GGML_OP_ADD;
2883
2884 // note: in Metal, we sometimes encode the graph in parallel, so we have to avoid fusing ops
2885 // across splits. idx_end indicates the last node in the current split
2886 for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
2887 if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
2888 break;
2889 }
2890
2891 ggml_tensor * f0 = ctx->node(idx + n_fuse);
2892 ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
2893
2894 if (f0 != f1->src[0]) {
2895 break;
2896 }
2897
2898 // b[0] === b[1] === ...
2899 if (!ggml_are_same_layout(f0->src[1], f1->src[1])) {
2900 break;
2901 }
2902
2903 // only fuse ops if src1 is in the same Metal buffer
2904 ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(f1->src[1]);
2905 if (bid_fuse.metal != bid_src1.metal) {
2906 break;
2907 }
2908
2909 //ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
2910
2911 args.o1[n_fuse + 1] = bid_fuse.offs;
2912 }
2913
2914 ++n_fuse;
2915
2916 if (debug_fusion > 1 && n_fuse > 1) {
2917 GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse);
2918 }
2919 }
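 // a minimal graph shape that triggers this fusion (a sketch; the tensor names
 // are hypothetical):
 //
 //   cur = ggml_add(ctx0, cur, bias0); // n_fuse = 1 so far
 //   cur = ggml_add(ctx0, cur, bias1); // fused if same layout + same Metal buffer
 //   cur = ggml_add(ctx0, cur, bias2); // ... up to 8 ADDs in one kernel launch
 //
 // the kernel then resolves each bias from args.o1[] as an offset into the
 // src1 buffer (see bid_src1.offs = 0 below)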
2920
2921 // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
2922 bid_src1.offs = 0;
2923
2924 struct ggml_metal_pipeline_with_params pipeline;
2925
2926 pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);
2927
2928 if (n_fuse > 1) {
2929 bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
2930
2931 for (int i = 1; i < n_fuse; ++i) {
2932 if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
2933 ggml_metal_op_concurrency_reset(ctx);
2934
2935 break;
2936 }
2937 }
2938 }
2939
2940 if (pipeline.c4) {
2941 args.ne00 = ne00/4;
2942 args.ne10 = ne10/4;
2943 args.ne0 = ne0/4;
2944 }
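 // the c4 pipeline variant processes float4 elements, so the per-row counts are
 // divided by 4 (e.g. ne00 = 4096 floats becomes 1024 float4 loads per row);
 // this assumes the rows are contiguous, which is asserted above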
2945
2946 ggml_metal_encoder_set_pipeline(enc, pipeline);
2947 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
2948 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
2949 ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
2950 ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
2951
2952 if (pipeline.cnt) {
2953 const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
2954
2955 ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
2956 } else {
2957 const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
2958
2959 int nth = 1;
2960
2961 while (2*nth < args.ne0 && nth < nth_max) {
2962 nth *= 2;
2963 }
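 // e.g. for args.ne0 = 96 the loop doubles nth up to 64 (2*64 >= 96 stops it),
 // so each of the ne01*ne02*ne03 threadgroups covers one row with 64 threads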
2964
2965 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
2966 }
2967
2968 return n_fuse;
2969}
2970
2971int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
2972 ggml_tensor * op = ctx->node(idx);
2973
2974 ggml_metal_library_t lib = ctx->lib;
2975 ggml_metal_encoder_t enc = ctx->enc;
2976
2977 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
2978 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
2979 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
2980 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
2981
2982 GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
2983
2984 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
2985 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
2986
2987 float eps;
2988 memcpy(&eps, op->op_params, sizeof(float));
2989
2990 ggml_metal_kargs_l2_norm args = {
2991 /*.ne00 =*/ ne00,
2992 /*.ne01 =*/ ne01,
2993 /*.ne02 =*/ ne02,
2994 /*.ne03 =*/ ne03,
2995 /*.nb00 =*/ nb00,
2996 /*.nb01 =*/ nb01,
2997 /*.nb02 =*/ nb02,
2998 /*.nb03 =*/ nb03,
2999 /*.ne0 =*/ ne0,
3000 /*.ne1 =*/ ne1,
3001 /*.ne2 =*/ ne2,
3002 /*.ne3 =*/ ne3,
3003 /*.nb0 =*/ nb0,
3004 /*.nb1 =*/ nb1,
3005 /*.nb2 =*/ nb2,
3006 /*.nb3 =*/ nb3,
3007 /*.eps =*/ eps,
3008 };
3009
3010 auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
3011
3012 if (pipeline.c4) {
3013 args.ne00 = ne00/4;
3014 args.ne0 = ne0/4;
3015 }
3016
3017 int nth = 32; // SIMD width
3018
3019 while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3020 nth *= 2;
3021 }
3022
3023 nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3024
3025 const size_t smem = pipeline.smem;
3026
3027 ggml_metal_encoder_set_pipeline(enc, pipeline);
3028 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3029 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3030 ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3031
3032 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3033
3034 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
3035
3036 return 1;
3037}
3038
3039int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
3040 ggml_tensor * op = ctx->node(idx);
3041
3042 ggml_metal_library_t lib = ctx->lib;
3043 ggml_metal_encoder_t enc = ctx->enc;
3044
3045 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3046 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3047 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3048 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3049
3050 const int32_t ngrp = ((const int32_t *) op->op_params)[0];
3051
3052 float eps;
3053 memcpy(&eps, op->op_params + 1, sizeof(float));
3054
3055 ggml_metal_kargs_group_norm args = {
3056 /*.ne00 =*/ ne00,
3057 /*.ne01 =*/ ne01,
3058 /*.ne02 =*/ ne02,
3059 /*.nb00 =*/ nb00,
3060 /*.nb01 =*/ nb01,
3061 /*.nb02 =*/ nb02,
3062 /*.ngrp =*/ ngrp,
3063 /*.eps =*/ eps,
3064 };
3065
3066 auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
3067
3068 int nth = 32; // SIMD width
3069 //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3070 // nth *= 2;
3071 //}
3072
3073 //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3074 //nth = std::min(nth, ne00/4);
3075
3076 const size_t smem = pipeline.smem;
3077
3078 ggml_metal_encoder_set_pipeline(enc, pipeline);
3079 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3080 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3081 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3082
3083 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3084
3085 ggml_metal_encoder_dispatch_threadgroups(enc, ngrp, 1, 1, nth, 1, 1);
3086
3087 return 1;
3088}
3089
3090int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
3091 ggml_tensor * op = ctx->node(idx);
3092
3093 ggml_metal_library_t lib = ctx->lib;
3094 ggml_metal_encoder_t enc = ctx->enc;
3095
3096 const bool use_fusion = ctx->use_fusion;
3097
3098 const int debug_fusion = ctx->debug_fusion;
3099
3100 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3101 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3102 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3103 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3104
3105 float eps;
3106 memcpy(&eps, op->op_params, sizeof(float));
3107
3108 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3109 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
3110
3111 ggml_metal_kargs_norm args = {
3112 /*.ne00 =*/ ne00,
3113 /*.ne00_t =*/ ne00 % 4 == 0 ? ne00/4 : ne00,
3114 /*.nb1 =*/ nb1,
3115 /*.nb2 =*/ nb2,
3116 /*.nb3 =*/ nb3,
3117 /*.eps =*/ eps,
3118 /*.nef1 =*/ { ne01 },
3119 /*.nef2 =*/ { ne02 },
3120 /*.nef3 =*/ { ne03 },
3121 /*.nbf1 =*/ { nb01 },
3122 /*.nbf2 =*/ { nb02 },
3123 /*.nbf3 =*/ { nb03 },
3124 };
3125
3126 ggml_op fops[8];
3127
3128 int n_fuse = 1;
3129
3130 ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };
3131
3132 // d[0] = norm(a)
3133 // d[1] = mul(d[0], b)
3134 // d[2] = add(d[1], c)
3135 if (use_fusion) {
3136 fops[0] = op->op;
3137 fops[1] = GGML_OP_MUL;
3138 fops[2] = GGML_OP_ADD;
3139
3140 for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
3141 if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
3142 break;
3143 }
3144
3145 ggml_tensor * f0 = ctx->node(idx + n_fuse);
3146 ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
3147
3148 if (f0 != f1->src[0]) {
3149 break;
3150 }
3151
3152 if (f1->src[1]->ne[0] != op->ne[0]) {
3153 break;
3154 }
3155
3156 if (!ggml_is_contiguous_rows(f1->src[1])) {
3157 break;
3158 }
3159
3160 if (f1->type != GGML_TYPE_F32) {
3161 break;
3162 }
3163
3164 //ctx->fuse_cnt[f1->op]++;
3165
3166 bid_fuse[n_fuse] = ggml_metal_get_buffer_id(f1->src[1]);
3167
3168 args.nef1[n_fuse + 1] = f1->src[1]->ne[1];
3169 args.nef2[n_fuse + 1] = f1->src[1]->ne[2];
3170 args.nef3[n_fuse + 1] = f1->src[1]->ne[3];
3171
3172 args.nbf1[n_fuse + 1] = f1->src[1]->nb[1];
3173 args.nbf2[n_fuse + 1] = f1->src[1]->nb[2];
3174 args.nbf3[n_fuse + 1] = f1->src[1]->nb[3];
3175 }
3176
3177 ++n_fuse;
3178
3179 if (debug_fusion > 1 && n_fuse > 1) {
3180 if (n_fuse == 2) {
3181 GGML_LOG_DEBUG("%s: fuse: %s + MUL\n", __func__, ggml_op_name(op->op));
3182 }
3183 if (n_fuse == 3) {
3184 GGML_LOG_DEBUG("%s: fuse: %s + MUL + ADD\n", __func__, ggml_op_name(op->op));
3185 }
3186 }
3187 }
3188
3189 if (n_fuse > 1) {
3190 bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
3191
3192 for (int i = 1; i < n_fuse; ++i) {
3193 if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
3194 ggml_metal_op_concurrency_reset(ctx);
3195
3196 break;
3197 }
3198 }
3199 }
3200
3201 auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
3202
3203 int nth = 32; // SIMD width
3204
3205 while (nth < args.ne00_t && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3206 nth *= 2;
3207 }
3208
3209 nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3210 nth = std::min(nth, args.ne00_t);
3211
3212 const size_t smem = pipeline.smem;
3213
3214 ggml_metal_encoder_set_pipeline(enc, pipeline);
3215 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3216 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3217 ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2);
3218 ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3);
3219 ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
3220
3221 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3222
3223 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
3224
3225 return n_fuse;
3226}
3227
3228int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
3229 ggml_tensor * op = ctx->node(idx);
3230
3231 ggml_metal_library_t lib = ctx->lib;
3232 ggml_metal_encoder_t enc = ctx->enc;
3233
3234 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3235 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3236 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3237 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3238 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3239 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3240
3241 // make sure we have one or more position ids (ne10) per token (ne02)
3242 GGML_ASSERT(ne10 % ne02 == 0);
3243 GGML_ASSERT(ne10 >= ne02);
3244
3245 const int nth = std::min(1024, ne00);
3246
3247 const int n_past = ((const int32_t *) op->op_params)[0];
3248 const int n_dims = ((const int32_t *) op->op_params)[1];
3249 //const int mode = ((const int32_t *) op->op_params)[2];
3250 // skip index 3 (n_ctx): used in GLM RoPE, not implemented in Metal
3251 const int n_ctx_orig = ((const int32_t *) op->op_params)[4];
3252
3253 float freq_base;
3254 float freq_scale;
3255 float ext_factor;
3256 float attn_factor;
3257 float beta_fast;
3258 float beta_slow;
3259
3260 memcpy(&freq_base, (const int32_t *) op->op_params + 5, sizeof(float));
3261 memcpy(&freq_scale, (const int32_t *) op->op_params + 6, sizeof(float));
3262 memcpy(&ext_factor, (const int32_t *) op->op_params + 7, sizeof(float));
3263 memcpy(&attn_factor, (const int32_t *) op->op_params + 8, sizeof(float));
3264 memcpy(&beta_fast, (const int32_t *) op->op_params + 9, sizeof(float));
3265 memcpy(&beta_slow, (const int32_t *) op->op_params + 10, sizeof(float));
3266
3267 // mrope
3268 const int sect_0 = ((const int32_t *) op->op_params)[11];
3269 const int sect_1 = ((const int32_t *) op->op_params)[12];
3270 const int sect_2 = ((const int32_t *) op->op_params)[13];
3271 const int sect_3 = ((const int32_t *) op->op_params)[14];
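 // op_params layout, as read above (indices into the int32_t view):
 //   [0] n_past  [1] n_dims  [2] mode  [3] n_ctx (GLM, unused)  [4] n_ctx_orig
 //   [5..10] floats: freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
 //   [11..14] mrope sections: sect_0..sect_3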
3272
3273 ggml_metal_kargs_rope args = {
3274 /*.ne00 =*/ ne00,
3275 /*.ne01 =*/ ne01,
3276 /*.ne02 =*/ ne02,
3277 /*.ne03 =*/ ne03,
3278 /*.nb00 =*/ nb00,
3279 /*.nb01 =*/ nb01,
3280 /*.nb02 =*/ nb02,
3281 /*.nb03 =*/ nb03,
3282 /*.ne0 =*/ ne0,
3283 /*.ne1 =*/ ne1,
3284 /*.ne2 =*/ ne2,
3285 /*.ne3 =*/ ne3,
3286 /*.nb0 =*/ nb0,
3287 /*.nb1 =*/ nb1,
3288 /*.nb2 =*/ nb2,
3289 /*.nb3 =*/ nb3,
3290 /*.n_past =*/ n_past,
3291 /*.n_dims =*/ n_dims,
3292 /*.n_ctx_orig =*/ n_ctx_orig,
3293 /*.freq_base =*/ freq_base,
3294 /*.freq_scale =*/ freq_scale,
3295 /*.ext_factor =*/ ext_factor,
3296 /*.attn_factor =*/ attn_factor,
3297 /*.beta_fast =*/ beta_fast,
3298 /*.beta_slow =*/ beta_slow,
3299 /*.sect_0 =*/ sect_0,
3300 /*.sect_1 =*/ sect_1,
3301 /*.sect_2 =*/ sect_2,
3302 /*.sect_3 =*/ sect_3,
3303 /*.src2 =*/ op->src[2] != nullptr,
3304 };
3305
3306 auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
3307
3308 ggml_metal_encoder_set_pipeline(enc, pipeline);
3309 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3310 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3311 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3312 if (op->src[2]) {
3313 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
3314 } else {
3315 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 3);
3316 }
3317 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 4);
3318
3319 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
3320
3321 return 1;
3322}
3323
3324int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
3325 ggml_tensor * op = ctx->node(idx);
3326
3327 ggml_metal_library_t lib = ctx->lib;
3328 ggml_metal_encoder_t enc = ctx->enc;
3329
3330 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3331 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3332 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3333 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3334
3335 const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3336 const int32_t s1 = ((const int32_t *)(op->op_params))[1];
3337 const int32_t p0 = ((const int32_t *)(op->op_params))[2];
3338 const int32_t p1 = ((const int32_t *)(op->op_params))[3];
3339 const int32_t d0 = ((const int32_t *)(op->op_params))[4];
3340 const int32_t d1 = ((const int32_t *)(op->op_params))[5];
3341
3342 const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1;
3343
3344 const int32_t N = op->src[1]->ne[is_2D ? 3 : 2];
3345 const int32_t IC = op->src[1]->ne[is_2D ? 2 : 1];
3346 const int32_t IH = is_2D ? op->src[1]->ne[1] : 1;
3347 const int32_t IW = op->src[1]->ne[0];
3348
3349 const int32_t KH = is_2D ? op->src[0]->ne[1] : 1;
3350 const int32_t KW = op->src[0]->ne[0];
3351
3352 const int32_t OH = is_2D ? op->ne[2] : 1;
3353 const int32_t OW = op->ne[1];
3354
3355 const int32_t CHW = IC * KH * KW;
3356
3357 const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4;
3358 const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4;
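 // worked example (a sketch): a 2D im2col for a 3x3 kernel over a 64x64 RGB
 // input with s0 = s1 = 1, p0 = p1 = 1, d0 = d1 = 1 gives IW = IH = 64, IC = 3,
 // KH = KW = 3, OH = OW = 64 and CHW = 3*3*3 = 27, i.e. each output position
 // gets the 27 unrolled values of its receptive field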
3359
3360 ggml_metal_kargs_im2col args = {
3361 /*.ofs0 =*/ ofs0,
3362 /*.ofs1 =*/ ofs1,
3363 /*.IW =*/ IW,
3364 /*.IH =*/ IH,
3365 /*.CHW =*/ CHW,
3366 /*.s0 =*/ s0,
3367 /*.s1 =*/ s1,
3368 /*.p0 =*/ p0,
3369 /*.p1 =*/ p1,
3370 /*.d0 =*/ d0,
3371 /*.d1 =*/ d1,
3372 /*.N =*/ N,
3373 /*.KH =*/ KH,
3374 /*.KW =*/ KW,
3375 /*.KHW =*/ KH * KW,
3376 };
3377
3378 auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
3379
3380 GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
3381
3382 const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);
3383
3384 ggml_metal_encoder_set_pipeline(enc, pipeline);
3385 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3386 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
3387 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3388
3389 ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
3390
3391 return 1;
3392}
3393
3394int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
3395 ggml_tensor * op = ctx->node(idx);
3396
3397 ggml_metal_library_t lib = ctx->lib;
3398 ggml_metal_encoder_t enc = ctx->enc;
3399
3400 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3401 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3402 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3403 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3404 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3405 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3406
3407 GGML_ASSERT(ggml_is_contiguous(op->src[0]));
3408 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
3409 GGML_ASSERT(op->type == GGML_TYPE_F32);
3410 GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
3411
3412 const int32_t s0 = ((const int32_t *) op->op_params)[0];
3413 const int32_t s1 = ((const int32_t *) op->op_params)[1];
3414 const int32_t p0 = ((const int32_t *) op->op_params)[2];
3415 const int32_t p1 = ((const int32_t *) op->op_params)[3];
3416 const int32_t d0 = ((const int32_t *) op->op_params)[4];
3417 const int32_t d1 = ((const int32_t *) op->op_params)[5];
3418
3419 ggml_metal_kargs_conv_2d args = {
3420 /*.nb00 =*/ nb00,
3421 /*.nb01 =*/ nb01,
3422 /*.nb02 =*/ nb02,
3423 /*.nb03 =*/ nb03,
3424 /*.nb10 =*/ nb10,
3425 /*.nb11 =*/ nb11,
3426 /*.nb12 =*/ nb12,
3427 /*.nb13 =*/ nb13,
3428 /*.nb0 =*/ nb0,
3429 /*.nb1 =*/ nb1,
3430 /*.nb2 =*/ nb2,
3431 /*.nb3 =*/ nb3,
3432 /*.IW =*/ ne10,
3433 /*.IH =*/ ne11,
3434 /*.KW =*/ ne00,
3435 /*.KH =*/ ne01,
3436 /*.IC =*/ ne02,
3437 /*.OC =*/ ne03,
3438 /*.OW =*/ ne0,
3439 /*.OH =*/ ne1,
3440 /*.N =*/ ne3,
3441 /*.s0 =*/ s0,
3442 /*.s1 =*/ s1,
3443 /*.p0 =*/ p0,
3444 /*.p1 =*/ p1,
3445 /*.d0 =*/ d0,
3446 /*.d1 =*/ d1,
3447 };
3448
3449 auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
3450
3451 int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
3452 nth = std::min(nth, 256);
3453 nth = std::max(nth, 1);
3454
3455 const uint64_t n_out = ggml_nelements(op);
3456
3457 uint64_t tg = (n_out + nth - 1)/nth;
3458 tg = std::max<uint64_t>(tg, 1);
3459 tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());
3460
3461 ggml_metal_encoder_set_pipeline(enc, pipeline);
3462 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3463 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3464 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3465 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
3466
3467 ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);
3468
3469 return 1;
3470}
3471
3472int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
3473 ggml_tensor * op = ctx->node(idx);
3474
3475 ggml_metal_library_t lib = ctx->lib;
3476 ggml_metal_encoder_t enc = ctx->enc;
3477
3478 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3479 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3480 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3481 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3482 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3483 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3484
3485 const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3486
3487 const int32_t IC = op->src[1]->ne[1];
3488 const int32_t IL = op->src[1]->ne[0];
3489
3490 const int32_t K = op->src[0]->ne[0];
3491
3492 const int32_t OL = op->ne[0];
3493 const int32_t OC = op->ne[1];
3494
3495 ggml_metal_kargs_conv_transpose_1d args = {
3496 /*.IC =*/ IC,
3497 /*.IL =*/ IL,
3498 /*.K =*/ K,
3499 /*.s0 =*/ s0,
3500 /*.nb0 =*/ nb0,
3501 /*.nb1 =*/ nb1,
3502 };
3503
3504 auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
3505
3506 ggml_metal_encoder_set_pipeline(enc, pipeline);
3507 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3508 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3509 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3510 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
3511
3512 ggml_metal_encoder_dispatch_threadgroups(enc, OL, OC, 1, 1, 1, 1);
3513
3514 return 1;
3515}
3516
3517int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
3518 ggml_tensor * op = ctx->node(idx);
3519
3520 ggml_metal_library_t lib = ctx->lib;
3521 ggml_metal_encoder_t enc = ctx->enc;
3522
3523 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3524 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3525 GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
3526 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
3527 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3528 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3529
3530 const int32_t s0 = ((const int32_t *)(op->op_params))[0];
3531
3532 const int32_t IC = op->src[1]->ne[2];
3533 const int32_t IH = op->src[1]->ne[1];
3534 const int32_t IW = op->src[1]->ne[0];
3535
3536 const int32_t KH = op->src[0]->ne[1];
3537 const int32_t KW = op->src[0]->ne[0];
3538
3539 const int32_t OW = op->ne[0];
3540 const int32_t OH = op->ne[1];
3541 const int32_t OC = op->ne[2];
3542
3543 ggml_metal_kargs_conv_transpose_2d args = {
3544 /*.IC =*/ IC,
3545 /*.IH =*/ IH,
3546 /*.IW =*/ IW,
3547 /*.KH =*/ KH,
3548 /*.KW =*/ KW,
3549 /*.OC =*/ OC,
3550 /*.s0 =*/ s0,
3551 /*.nb0 =*/ nb0,
3552 /*.nb1 =*/ nb1,
3553 /*.nb2 =*/ nb2,
3554 };
3555
3556 auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
3557
3558 ggml_metal_encoder_set_pipeline(enc, pipeline);
3559 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3560 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3561 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
3562 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
3563
3564 // Metal requires the threadgroup memory size to be a multiple of 16 bytes
3565 const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16);
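 // e.g. a 3x3 kernel: KW*KH*sizeof(float) = 36 bytes, padded up to 48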
3566 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3567
3568 ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);
3569
3570 return 1;
3571}
3572
3573int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
3574 ggml_tensor * op = ctx->node(idx);
3575
3576 ggml_metal_library_t lib = ctx->lib;
3577 ggml_metal_encoder_t enc = ctx->enc;
3578
3579 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3580 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3581 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3582 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3583
3584 const float sf0 = (float)ne0/op->src[0]->ne[0];
3585 const float sf1 = (float)ne1/op->src[0]->ne[1];
3586 const float sf2 = (float)ne2/op->src[0]->ne[2];
3587 const float sf3 = (float)ne3/op->src[0]->ne[3];
3588
3589 ggml_metal_kargs_upscale args = {
3590 /*.ne00 =*/ ne00,
3591 /*.ne01 =*/ ne01,
3592 /*.ne02 =*/ ne02,
3593 /*.ne03 =*/ ne03,
3594 /*.nb00 =*/ nb00,
3595 /*.nb01 =*/ nb01,
3596 /*.nb02 =*/ nb02,
3597 /*.nb03 =*/ nb03,
3598 /*.ne0 =*/ ne0,
3599 /*.ne1 =*/ ne1,
3600 /*.ne2 =*/ ne2,
3601 /*.ne3 =*/ ne3,
3602 /*.nb0 =*/ nb0,
3603 /*.nb1 =*/ nb1,
3604 /*.nb2 =*/ nb2,
3605 /*.nb3 =*/ nb3,
3606 /*.sf0 =*/ sf0,
3607 /*.sf1 =*/ sf1,
3608 /*.sf2 =*/ sf2,
3609 /*.sf3 =*/ sf3
3610 };
3611
3612 auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
3613
3614 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
3615
3616 ggml_metal_encoder_set_pipeline(enc, pipeline);
3617 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3618 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3619 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3620
3621 ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
3622
3623 return 1;
3624}
3625
3626int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
3627 ggml_tensor * op = ctx->node(idx);
3628
3629 ggml_metal_library_t lib = ctx->lib;
3630 ggml_metal_encoder_t enc = ctx->enc;
3631
3632 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3633 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3634 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3635 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3636
3637 ggml_metal_kargs_pad args = {
3638 /*.ne00 =*/ ne00,
3639 /*.ne01 =*/ ne01,
3640 /*.ne02 =*/ ne02,
3641 /*.ne03 =*/ ne03,
3642 /*.nb00 =*/ nb00,
3643 /*.nb01 =*/ nb01,
3644 /*.nb02 =*/ nb02,
3645 /*.nb03 =*/ nb03,
3646 /*.ne0 =*/ ne0,
3647 /*.ne1 =*/ ne1,
3648 /*.ne2 =*/ ne2,
3649 /*.ne3 =*/ ne3,
3650 /*.nb0 =*/ nb0,
3651 /*.nb1 =*/ nb1,
3652 /*.nb2 =*/ nb2,
3653 /*.nb3 =*/ nb3
3654 };
3655
3656 auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
3657
3658 const int nth = std::min(1024, ne0);
3659
3660 ggml_metal_encoder_set_pipeline(enc, pipeline);
3661 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3662 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3663 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3664
3665 ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
3666
3667 return 1;
3668}
3669
3670int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
3671 ggml_tensor * op = ctx->node(idx);
3672
3673 ggml_metal_library_t lib = ctx->lib;
3674 ggml_metal_encoder_t enc = ctx->enc;
3675
3676 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3677 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3678 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3679 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3680
3681 ggml_metal_kargs_pad_reflect_1d args = {
3682 /*.ne00 =*/ ne00,
3683 /*.ne01 =*/ ne01,
3684 /*.ne02 =*/ ne02,
3685 /*.ne03 =*/ ne03,
3686 /*.nb00 =*/ nb00,
3687 /*.nb01 =*/ nb01,
3688 /*.nb02 =*/ nb02,
3689 /*.nb03 =*/ nb03,
3690 /*.ne0 =*/ ne0,
3691 /*.ne1 =*/ ne1,
3692 /*.ne2 =*/ ne2,
3693 /*.ne3 =*/ ne3,
3694 /*.nb0 =*/ nb0,
3695 /*.nb1 =*/ nb1,
3696 /*.nb2 =*/ nb2,
3697 /*.nb3 =*/ nb3,
3698 /*.p0 =*/ ((const int32_t *)(op->op_params))[0],
3699 /*.p1 =*/ ((const int32_t *)(op->op_params))[1]
3700 };
3701
3702 auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
3703
3704 const int nth = std::min(1024, ne0);
3705
3706 ggml_metal_encoder_set_pipeline(enc, pipeline);
3707 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3708 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3709 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3710
3711 ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
3712
3713 return 1;
3714}
3715
3716int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
3717 ggml_tensor * op = ctx->node(idx);
3718
3719 ggml_metal_library_t lib = ctx->lib;
3720 ggml_metal_encoder_t enc = ctx->enc;
3721
3722 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3723 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3724
3725 float start;
3726 float step;
3727
3728 memcpy(&start, ((const int32_t *) op->op_params) + 0, sizeof(float));
3729 memcpy(&step, ((const int32_t *) op->op_params) + 2, sizeof(float));
3730
3731 ggml_metal_kargs_arange args = {
3732 /*.ne0 =*/ ne0,
3733 /*.start =*/ start,
3734 /*.step =*/ step
3735 };
3736
3737 const int nth = std::min(1024, ne0);
3738
3739 auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
3740
3741 ggml_metal_encoder_set_pipeline(enc, pipeline);
3742 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3743 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
3744
3745 ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);
3746
3747 return 1;
3748}
3749
3750int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
3751 ggml_tensor * op = ctx->node(idx);
3752
3753 ggml_metal_library_t lib = ctx->lib;
3754 ggml_metal_encoder_t enc = ctx->enc;
3755
3756 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3757 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3758 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3759 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3760
3761 const int dim = op->op_params[0];
3762 const int max_period = op->op_params[1];
3763
3764 ggml_metal_kargs_timestep_embedding args = {
3765 /*.nb1 =*/ nb1,
3766 /*.dim =*/ dim,
3767 /*.max_period =*/ max_period,
3768 };
3769
3770 auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
3771
3772 const int nth = std::max(1, std::min(1024, dim/2));
3773
3774 ggml_metal_encoder_set_pipeline(enc, pipeline);
3775 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3776 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3777 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3778
3779 ggml_metal_encoder_dispatch_threadgroups(enc, ne00, 1, 1, nth, 1, 1);
3780
3781 return 1;
3782}
3783
3784int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
3785 ggml_tensor * op = ctx->node(idx);
3786
3787 ggml_metal_library_t lib = ctx->lib;
3788 ggml_metal_encoder_t enc = ctx->enc;
3789
3790 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3791 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3792 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3793 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3794
3795 ggml_metal_kargs_argmax args = {
3796 /*.ne00 = */ ne00,
3797 /*.nb01 = */ nb01,
3798 };
3799
3800 auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
3801
3802 const int64_t nrows = ggml_nrows(op->src[0]);
3803
3804 int nth = 32; // SIMD width
3805 while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
3806 nth *= 2;
3807 }
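 // e.g. a single row (ne01*ne02*ne03 = 1) of ne00 = 4096 values: nth doubles
 // 32 -> 64 -> 128 -> 256 and stops at 256, the per-threadgroup budget used here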
3808
3809 const size_t smem = pipeline.smem;
3810
3811 ggml_metal_encoder_set_pipeline(enc, pipeline);
3812 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3813 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
3814 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
3815
3816 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3817
3818 ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
3819
3820 return 1;
3821}
3822
3823int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
3824 ggml_tensor * op = ctx->node(idx);
3825
3826 ggml_metal_library_t lib = ctx->lib;
3827 ggml_metal_encoder_t enc = ctx->enc;
3828
3829 GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
3830
3831 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3832 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3833 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3834 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3835
3836 auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
3837
3838 // bitonic sort requires the number of elements to be a power of 2
3839 int nth = 1;
3840 while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3841 nth *= 2;
3842 }
3843
3844 const int npr = (ne00 + nth - 1)/nth;
3845
3846 // Metal kernels require the threadgroup memory size to be a multiple of 16 bytes
3847 // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
3848 const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
3849
3850 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3851 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
3852
3853 ggml_metal_buffer_id bid_tmp = bid_dst;
3854 bid_tmp.offs += ggml_nbytes(op);
3855
3856 if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
3857 std::swap(bid_dst, bid_tmp);
3858 }
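 // the merge loop below ping-pongs between bid_dst and bid_tmp once per doubling
 // of len; swapping up front when ceil(log2(npr)) is odd makes the final swap
 // land the sorted indices in the real dst buffer. e.g. ne00 = 1500, nth = 1024
 // gives npr = 2 blocks and a single merge pass, so dst and tmp start swapped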
3859
3860 ggml_metal_kargs_argsort args = {
3861 /*.ne00 =*/ ne00,
3862 /*.ne01 =*/ ne01,
3863 /*.ne02 =*/ ne02,
3864 /*.ne03 =*/ ne03,
3865 /*.nb00 =*/ nb00,
3866 /*.nb01 =*/ nb01,
3867 /*.nb02 =*/ nb02,
3868 /*.nb03 =*/ nb03,
3869 /*.ne0 =*/ ne0,
3870 /*.ne1 =*/ ne1,
3871 /*.ne2 =*/ ne2,
3872 /*.ne3 =*/ ne3,
3873 /*.top_k =*/ nth,
3874 };
3875
3876 ggml_metal_encoder_set_pipeline(enc, pipeline);
3877 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3878 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3879 ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3880
3881 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3882
3883 ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
3884
3885 auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
3886
3887 int len = nth;
3888
3889 while (len < ne00) {
3890 ggml_metal_op_concurrency_reset(ctx);
3891
3892 ggml_metal_kargs_argsort_merge args_merge = {
3893 /*.ne00 =*/ ne00,
3894 /*.ne01 =*/ ne01,
3895 /*.ne02 =*/ ne02,
3896 /*.ne03 =*/ ne03,
3897 /*.nb00 =*/ nb00,
3898 /*.nb01 =*/ nb01,
3899 /*.nb02 =*/ nb02,
3900 /*.nb03 =*/ nb03,
3901 /*.ne0 =*/ ne0,
3902 /*.ne1 =*/ ne1,
3903 /*.ne2 =*/ ne2,
3904 /*.ne3 =*/ ne3,
3905 /*.top_k =*/ ne00,
3906 /*.len =*/ len,
3907 };
3908
3909 // merges per row
3910 const int nm = (ne00 + 2*len - 1) / (2*len);
3911
3912 const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge));
3913
3914 ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
3915 ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
3916 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3917 ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3918 ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
3919
3920 ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
3921
3922 std::swap(bid_dst, bid_tmp);
3923
3924 len <<= 1;
3925 }
3926
3927 return 1;
3928}
3929
3930int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
3931 ggml_tensor * op = ctx->node(idx);
3932
3933 ggml_metal_library_t lib = ctx->lib;
3934 ggml_metal_encoder_t enc = ctx->enc;
3935
3936 GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
3937
3938 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
3939 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
3940 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
3941 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
3942
3943 auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
3944
3945 // bitonic sort requires the number of elements to be a power of 2
3946 int nth = 1;
3947 while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
3948 nth *= 2;
3949 }
3950
3951 // blocks per row
3952 const int npr = (ne00 + nth - 1)/nth;
3953
3954 const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16);
3955
3956 ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
3957 ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
3958
3959 ggml_metal_buffer_id bid_tmp = bid_dst;
3960 bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]);
3961
3962 if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
3963 std::swap(bid_dst, bid_tmp);
3964 }
3965
3966 const int top_k = ne0;
3967
3968 ggml_metal_kargs_argsort args = {
3969 /*.ne00 =*/ ne00,
3970 /*.ne01 =*/ ne01,
3971 /*.ne02 =*/ ne02,
3972 /*.ne03 =*/ ne03,
3973 /*.nb00 =*/ nb00,
3974 /*.nb01 =*/ nb01,
3975 /*.nb02 =*/ nb02,
3976 /*.nb03 =*/ nb03,
3977 /*.ne0 =*/ ne0,
3978 /*.ne1 =*/ ne1,
3979 /*.ne2 =*/ ne2,
3980 /*.ne3 =*/ ne3,
3981 /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices
3982 };
3983
3984 if (npr > 1) {
3985 args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k);
3986 }
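 // e.g. ne00 = 1000, nth = 256, top_k = 50: npr = 4 blocks, each keeping
 // args.top_k = 50 candidates, so args.ne0 = 3*50 + min(1000 - 3*256, 50) = 200
 // elements survive into the merge passes below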
3987
3988 ggml_metal_encoder_set_pipeline(enc, pipeline);
3989 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
3990 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
3991 ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
3992
3993 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
3994
3995 ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
3996
3997 auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
3998
3999 int len = args.top_k;
4000
4001 while (len < args.ne0) {
4002 ggml_metal_op_concurrency_reset(ctx);
4003
4004 // merges per row
4005 const int nm = (args.ne0 + 2*len - 1) / (2*len);
4006
4007 const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)));
4008
4009 ggml_metal_kargs_argsort_merge args_merge = {
4010 /*.ne00 =*/ ne00,
4011 /*.ne01 =*/ ne01,
4012 /*.ne02 =*/ ne02,
4013 /*.ne03 =*/ ne03,
4014 /*.nb00 =*/ nb00,
4015 /*.nb01 =*/ nb01,
4016 /*.nb02 =*/ nb02,
4017 /*.nb03 =*/ nb03,
4018 /*.ne0 =*/ args.ne0,
4019 /*.ne1 =*/ ne1,
4020 /*.ne2 =*/ ne2,
4021 /*.ne3 =*/ ne3,
4022 /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements
4023 /*.len =*/ len,
4024 };
4025
4026 ggml_metal_encoder_set_pipeline(enc, pipeline_merge);
4027 ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0);
4028 ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
4029 ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
4030 ggml_metal_encoder_set_buffer (enc, bid_tmp, 3);
4031
4032 ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1);
4033
4034 std::swap(bid_dst, bid_tmp);
4035
4036 len <<= 1;
4037 }
4038
4039 return 1;
4040}
4041
4042int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
4043 ggml_tensor * op = ctx->node(idx);
4044
4045 ggml_metal_library_t lib = ctx->lib;
4046 ggml_metal_encoder_t enc = ctx->enc;
4047
4048 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
4049 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4050 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
4051 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
4052
4053 ggml_metal_kargs_tri args = {
4054 /*.ne00 =*/ ne00,
4055 /*.ne01 =*/ ne01,
4056 /*.ne02 =*/ ne02,
4057 /*.ne03 =*/ ne03,
4058 /*.nb00 =*/ nb00,
4059 /*.nb01 =*/ nb01,
4060 /*.nb02 =*/ nb02,
4061 /*.nb03 =*/ nb03,
4062 /*.ne0 =*/ ne0,
4063 /*.ne1 =*/ ne1,
4064 /*.ne2 =*/ ne2,
4065 /*.ne3 =*/ ne3,
4066 /*.nb0 =*/ nb0,
4067 /*.nb1 =*/ nb1,
4068 /*.nb2 =*/ nb2,
4069 /*.nb3 =*/ nb3,
4070 };
4071
4072 auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op);
4073
4074 int nth = 32; // SIMD width
4075
4076 while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
4077 nth *= 2;
4078 }
4079
4080 nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
4081 nth = std::min(nth, ne00);
4082
4083 ggml_metal_encoder_set_pipeline(enc, pipeline);
4084 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
4085 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
4086 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
4087
4088 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
4089
4090 return 1;
4091}
4092
4093int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
4094 ggml_tensor * op = ctx->node(idx);
4095
4096 ggml_metal_library_t lib = ctx->lib;
4097 ggml_metal_encoder_t enc = ctx->enc;
4098
4099 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
4100 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4101 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
4102 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
4103
4104 auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
4105
4106 const int64_t np = ggml_nelements(op->src[0]);
4107 ggml_metal_kargs_opt_step_adamw args = {
4108 /*.np =*/ np,
4109 };
4110
4111 int ida = 0;
4112
4113 ggml_metal_encoder_set_pipeline(enc, pipeline);
4114 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
4115 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
4116 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
4117 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
4118 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
4119 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);
4120
4121 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
4122 const int64_t n = (np + nth - 1) / nth;
4123
4124 ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
4125
4126 return 1;
4127}
4128
4129int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
4130 ggml_tensor * op = ctx->node(idx);
4131
4132 ggml_metal_library_t lib = ctx->lib;
4133 ggml_metal_encoder_t enc = ctx->enc;
4134
4135 GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
4136 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4137 GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
4138 GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
4139
4140 auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
4141
4142 const int64_t np = ggml_nelements(op->src[0]);
4143 ggml_metal_kargs_opt_step_sgd args = {
4144 /*.np =*/ np,
4145 };
4146
4147 int ida = 0;
4148
4149 ggml_metal_encoder_set_pipeline(enc, pipeline);
4150 ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
4151 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
4152 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
4153 ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
4154
4155 const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
4156 const int64_t n = (np + nth - 1) / nth;
4157
4158 ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
4159
4160 return 1;
4161}
4162
4163int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
4164 ggml_tensor * op = ctx->node(idx);
4165
4166 ggml_metal_library_t lib = ctx->lib;
4167 ggml_metal_encoder_t enc = ctx->enc;
4168
4169 GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
4170 GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
4171 GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
4172
4173 {
4174 ggml_metal_kargs_memset args = { /*.val =*/ 0 };
4175
4176 auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op);
4177
4178 ggml_metal_encoder_set_pipeline(enc, pipeline);
4179 ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
4180 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);
4181
4182 ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
4183 }
4184
4185 ggml_metal_op_concurrency_reset(ctx);
4186
4187 {
4188 ggml_metal_kargs_count_equal args = {
4189 /*.ne00 =*/ ne00,
4190 /*.ne01 =*/ ne01,
4191 /*.ne02 =*/ ne02,
4192 /*.ne03 =*/ ne03,
4193 /*.nb00 =*/ nb00,
4194 /*.nb01 =*/ nb01,
4195 /*.nb02 =*/ nb02,
4196 /*.nb03 =*/ nb03,
4197 /*.nb10 =*/ nb10,
4198 /*.nb11 =*/ nb11,
4199 /*.nb12 =*/ nb12,
4200 /*.nb13 =*/ nb13,
4201 };
4202
4203 auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);
4204
4205 const size_t smem = pipeline.smem;
4206
4207 const int nth = 32*pipeline.nsg;
4208
4209 GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
4210
4211 ggml_metal_encoder_set_pipeline(enc, pipeline);
4212 ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
4213 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
4214 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
4215 ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
4216
4217 ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
4218 ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
4219 }
4220
4221 return 1;
4222}
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.h b/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.h
new file mode 100644
index 0000000..29456d7
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -0,0 +1,93 @@
1#pragma once
2
3#include "ggml-metal-device.h"
4
5#ifdef __cplusplus
6extern "C" {
7#endif
8
9typedef struct ggml_metal_op * ggml_metal_op_t;
10
11ggml_metal_op_t ggml_metal_op_init(
12 ggml_metal_device_t dev,
13 ggml_metal_cmd_buf_t cmd_buf,
14 struct ggml_cgraph * gf,
15 int idx_start,
16 int idx_end,
17 bool use_fusion,
18 bool use_concurrency,
19 bool use_capture,
20 int debug_graph,
21 int debug_fusion);
22
23void ggml_metal_op_free(ggml_metal_op_t ctx);
24
25int ggml_metal_op_n_nodes(ggml_metal_op_t ctx);
26
27int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx);
28
29//
30// available ops:
31//
32
33// tokens per expert
34size_t ggml_metal_op_mul_mat_id_extra_tpe(const struct ggml_tensor * op);
35
36// id map [n_tokens, n_expert]
37size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
38
39// return true if we should use the FA vector kernel for this op
40bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
41
42size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);
43size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op);
44size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
45
46int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
47int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
48int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
49int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
50int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
51int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
52int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
53int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx);
54int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
55int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
56int ggml_metal_op_diag (ggml_metal_op_t ctx, int idx);
57int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
58int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
59int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
60int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
61int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
62int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
63int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx);
64int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);
65int ggml_metal_op_mul_mat (ggml_metal_op_t ctx, int idx);
66int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx);
67int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
68int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
69int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
70int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
71int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
72int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
73int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
74int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx);
75int ggml_metal_op_conv_2d (ggml_metal_op_t ctx, int idx);
76int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
77int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
78int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx);
79int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx);
80int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx);
81int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx);
82int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
83int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
84int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
85int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
86int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
87int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
88int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
89int ggml_metal_op_count_equal (ggml_metal_op_t ctx, int idx);
90
91#ifdef __cplusplus
92}
93#endif
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal.cpp b/llama.cpp/ggml/src/ggml-metal/ggml-metal.cpp
new file mode 100644
index 0000000..1c70536
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal.cpp
@@ -0,0 +1,937 @@
1#include "ggml-metal.h"
2
3#include "ggml-impl.h"
4#include "ggml-backend-impl.h"
5
6#include "ggml-metal-device.h"
7#include "ggml-metal-context.h"
8#include "ggml-metal-ops.h"
9
10#include <mutex>
11#include <string>
12
13#define GGML_METAL_NAME "MTL"
14#define GGML_METAL_MAX_DEVICES 16
15
16// number of Metal devices
17// note: can be overridden with the GGML_METAL_DEVICES env var to simulate virtual devices
18static int g_devices = 1;
19
20////////////////////////////////////////////////////////////////////////////////
21// backend interface
22////////////////////////////////////////////////////////////////////////////////
23
24// shared buffer
25
26static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t buffer) {
27 ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
28
29 GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
30
31 ggml_metal_buffer_free(ctx);
32}
33
34static void * ggml_backend_metal_buffer_shared_get_base(ggml_backend_buffer_t buffer) {
35 ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
36
37 GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
38
39 return ggml_metal_buffer_get_base(ctx);
40}
41
42static void ggml_backend_metal_buffer_shared_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
43 ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
44
45 GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
46
47 ggml_metal_buffer_memset_tensor(ctx, tensor, value, offset, size);
48}
49
50static void ggml_backend_metal_buffer_shared_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
51 ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
52
53 GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
54
55 ggml_metal_buffer_set_tensor(ctx, tensor, data, offset, size);
56}
57
58static void ggml_backend_metal_buffer_shared_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
59 ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
60
61 GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
62
63 ggml_metal_buffer_get_tensor(ctx, tensor, data, offset, size);
64}
65
66static bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
67 ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
68
69 GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
70
71 GGML_UNUSED(buffer);
72 GGML_UNUSED(src);
73 GGML_UNUSED(dst);
74
75 return false;
76}
77
78static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) {
79 ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
80
81 GGML_ASSERT(ggml_metal_buffer_is_shared(ctx));
82
83 ggml_metal_buffer_clear(ctx, value);
84}
85
86static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = {
87 /* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer,
88 /* .get_base = */ ggml_backend_metal_buffer_shared_get_base,
89 /* .init_tensor = */ NULL,
90 /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor,
91 /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor,
92 /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor,
93 /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor,
94 /* .clear = */ ggml_backend_metal_buffer_shared_clear,
95 /* .reset = */ NULL,
96};
97
98// private buffer
99
100static void ggml_backend_metal_buffer_private_free_buffer(ggml_backend_buffer_t buffer) {
101 ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
102
103 GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
104
105 ggml_metal_buffer_free(ctx);
106}
107
108static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) {
109 ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
110
111 GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
112
113 return ggml_metal_buffer_get_base(ctx);
114}
115
116static void ggml_backend_metal_buffer_private_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
117 ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
118
119 GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
120
121 ggml_metal_buffer_memset_tensor(ctx, tensor, value, offset, size);
122}
123
124static void ggml_backend_metal_buffer_private_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
125 ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
126
127 GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
128
129 ggml_metal_buffer_set_tensor(ctx, tensor, data, offset, size);
130}
131
132static void ggml_backend_metal_buffer_private_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
133 ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
134
135 GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
136
137 ggml_metal_buffer_get_tensor(ctx, tensor, data, offset, size);
138}
139
140static bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
141 ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
142
143 GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
144
145 GGML_UNUSED(buffer);
146 GGML_UNUSED(src);
147 GGML_UNUSED(dst);
148
149 return false;
150}
151
152static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) {
153 ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context;
154
155 GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx));
156
157 ggml_metal_buffer_clear(ctx, value);
158}
159
160static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = {
161 /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer,
162 /* .get_base = */ ggml_backend_metal_buffer_private_get_base,
163 /* .init_tensor = */ NULL,
164 /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor,
165 /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor,
166 /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor,
167 /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor,
168 /* .clear = */ ggml_backend_metal_buffer_private_clear,
169 /* .reset = */ NULL,
170};
171
172static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) {
173 return buffer->iface.free_buffer == ggml_backend_metal_buffer_shared_free_buffer ||
174 buffer->iface.free_buffer == ggml_backend_metal_buffer_private_free_buffer;
175}
176
177//
178// buffer types
179//
180
181struct ggml_backend_metal_buffer_type {
182 int device;
183 std::string name;
184};
185
186struct ggml_backend_metal_buffer_type_deleter {
187 void operator()(ggml_backend_metal_buffer_type * ctx) const {
188 delete ctx;
189 }
190};
191
192typedef std::unique_ptr<ggml_backend_metal_buffer_type, ggml_backend_metal_buffer_type_deleter> ggml_backend_metal_buffer_type_ptr;
193
194// common method for allocating shared or private Metal buffers
195static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) {
196 ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;
197 ggml_metal_buffer_t res = ggml_metal_buffer_init(ctx_dev, size, shared);
198
199 ggml_backend_buffer_i buf_i = ggml_metal_buffer_is_shared(res)
200 ? ggml_backend_metal_buffer_shared_i
201 : ggml_backend_metal_buffer_private_i;
202
203 return ggml_backend_buffer_init(buft, buf_i, res, size);
204}
205
206static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
207 size_t res = ggml_nbytes(tensor);
208
209 // some operations require additional memory for fleeting data:
210 switch (tensor->op) {
211 case GGML_OP_MUL_MAT_ID:
212 {
213 res += ggml_metal_op_mul_mat_id_extra_tpe(tensor);
214 res += ggml_metal_op_mul_mat_id_extra_ids(tensor);
215 } break;
216 case GGML_OP_FLASH_ATTN_EXT:
217 {
218 res += ggml_metal_op_flash_attn_ext_extra_pad(tensor);
219 res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
220 res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
221 } break;
222 case GGML_OP_CUMSUM:
223 case GGML_OP_ARGSORT:
224 {
225 res *= 2;
226 } break;
227 case GGML_OP_TOP_K:
228 {
229 res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]);
230 } break;
231 default:
232 break;
233 }
234
235 return res;
236
237 GGML_UNUSED(buft);
238}
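// worked example (illustrative sizes): a 1024-element GGML_OP_ARGSORT output
// (int32 indices) has ggml_nbytes() == 4096, so the "res *= 2" branch above
// reserves 8192 bytes - the extra half serves as scratch for the sort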
239
240// default (shared) buffer type
241
242static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) {
243 ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context;
244
245 return ctx->name.c_str();
246}
247
248static ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
249 return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true);
250}
251
252static size_t ggml_backend_metal_buffer_type_shared_get_alignment(ggml_backend_buffer_type_t buft) {
253 return 32;
254
255 GGML_UNUSED(buft);
256}
257
258static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_buffer_type_t buft) {
259 ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;
260
261 return ggml_metal_device_get_props(ctx_dev)->max_buffer_size;
262}
263
264static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
265 return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
266}
267
268static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) {
269 return false;
270
271 GGML_UNUSED(buft);
272}
273
274static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(int device) {
275 static std::mutex mutex;
276 std::lock_guard<std::mutex> lock(mutex);
277
278 static std::vector<ggml_backend_buffer_type> bufts;
279 static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs;
280
281 static bool initialized = false;
282 if (!initialized) {
283 bufts.reserve(g_devices);
284 ctxs.reserve(g_devices);
285
286 for (int i = 0; i < g_devices; ++i) {
287 ggml_backend_metal_buffer_type * raw_ctx =
288 new ggml_backend_metal_buffer_type {
289 /* .device = */ i,
290 /* .name = */ GGML_METAL_NAME + std::to_string(i),
291 };
292 ctxs.emplace_back(raw_ctx);
293
294 ggml_backend_buffer_type buft = {
295 /* .iface = */ {
296 /* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name,
297 /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer,
298 /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment,
299 /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size,
300 /* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size,
301 /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host,
302 },
303 /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i),
304 /* .context = */ raw_ctx,
305 };
306
307 bufts.emplace_back(buft);
308 }
309
310 initialized = true;
311 }
312
313 return &bufts[device];
314}
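// usage sketch (assumption, not part of this file): user code reaches this
// buffer type through the generic device interface rather than by calling the
// static function above directly, e.g.
//
//   ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
//   ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(dev);
//   ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, 16u*1024*1024);
//   // ... use the buffer ...
//   ggml_backend_buffer_free(buf);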
315
316// default (private) buffer type
317
318static const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) {
319 ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context;
320
321 return ctx->name.c_str();
322}
323
324static ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
325 return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, false);
326}
327
328static size_t ggml_backend_metal_buffer_type_private_get_alignment(ggml_backend_buffer_type_t buft) {
329 return 32;
330
331 GGML_UNUSED(buft);
332}
333
334static size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_buffer_type_t buft) {
335 ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;
336
337 return ggml_metal_device_get_props(ctx_dev)->max_buffer_size;
338}
339
340static size_t ggml_backend_metal_buffer_type_private_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
341 return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
342}
343
344static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) {
345 return false;
346
347 GGML_UNUSED(buft);
348}
349
350static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(int device) {
351 static std::mutex mutex;
352 std::lock_guard<std::mutex> lock(mutex);
353
354 static std::vector<ggml_backend_buffer_type> bufts;
355 static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs;
356
357 static bool initialized = false;
358 if (!initialized) {
359 bufts.reserve(g_devices);
360 ctxs.reserve(g_devices);
361
362 for (int i = 0; i < g_devices; ++i) {
363 ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{
364 /* .device = */ i,
365 /* .name = */ GGML_METAL_NAME + std::to_string(i) + "_Private"
366 };
367 ctxs.emplace_back(raw_ctx);
368
369 ggml_backend_buffer_type buft = {
370 /* .iface = */ {
371 /* .get_name = */ ggml_backend_metal_buffer_type_private_get_name,
372 /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer,
373 /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment,
374 /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size,
375 /* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size,
376 /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host,
377 },
378 /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i),
379 /* .context = */ raw_ctx,
380 };
381
382 bufts.emplace_back(buft);
383 }
384
385 initialized = true;
386 }
387
388 return &bufts[device];
389}
390
391// mapped buffer type
392
393static const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) {
394 ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context;
395
396 return ctx->name.c_str();
397}
398
399static ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
400 // for mapped buffers, prefer shared memory
401 return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true);
402}
403
404static size_t ggml_backend_metal_buffer_type_mapped_get_alignment(ggml_backend_buffer_type_t buft) {
405 return 32;
406
407 GGML_UNUSED(buft);
408}
409
410static size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_buffer_type_t buft) {
411 ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context;
412
413 return ggml_metal_device_get_props(ctx_dev)->max_buffer_size;
414}
415
416static size_t ggml_backend_metal_buffer_type_mapped_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
417 return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
418}
419
420static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) {
421 return false;
422
423 GGML_UNUSED(buft);
424}
425
426static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(int device) {
427 static std::mutex mutex;
428 std::lock_guard<std::mutex> lock(mutex);
429
430 static std::vector<ggml_backend_buffer_type> bufts;
431 static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs;
432
433 static bool initialized = false;
434 if (!initialized) {
435 bufts.reserve(g_devices);
436 ctxs.reserve(g_devices);
437
438 for (int i = 0; i < g_devices; ++i) {
439 ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{
440 /* .device = */ i,
441 /* .name = */ GGML_METAL_NAME + std::to_string(i) + "_Mapped"
442 };
443 ctxs.emplace_back(raw_ctx);
444
445 // note: not obvious, but this buffer type still needs to implement .alloc_buffer:
446 // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099
447 ggml_backend_buffer_type buft = {
448 /* .iface = */ {
449 /* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name,
450 /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer,
451 /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment,
452 /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size,
453 /* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size,
454 /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host,
455 },
456 /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i),
457 /* .context = */ raw_ctx,
458 };
459
460 bufts.emplace_back(buft);
461 }
462
463 initialized = true;
464 }
465
466 return &bufts[device];
467}
468
469// backend
470
471static const char * ggml_backend_metal_name(ggml_backend_t backend) {
472 ggml_metal_t ctx = (ggml_metal_t)backend->context;
473
474 return ggml_metal_get_name(ctx);
475}
476
477static void ggml_backend_metal_free(ggml_backend_t backend) {
478 ggml_metal_t ctx = (ggml_metal_t)backend->context;
479
480 // wait for any ongoing async operations to finish
481 ggml_metal_synchronize(ctx);
482
483 ggml_metal_free(ctx);
484
485 free(backend);
486}
487
488static void ggml_backend_metal_synchronize(ggml_backend_t backend) {
489 ggml_metal_t ctx = (ggml_metal_t)backend->context;
490
491 ggml_metal_synchronize(ctx);
492}
493
494static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
495 ggml_metal_t ctx = (ggml_metal_t)backend->context;
496
497 ggml_metal_set_tensor_async(ctx, tensor, data, offset, size);
498}
499
500static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
501 ggml_metal_t ctx = (ggml_metal_t)backend->context;
502
503 ggml_metal_get_tensor_async(ctx, tensor, data, offset, size);
504}
505
506static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
507 if (!ggml_backend_is_metal(backend_src) || !ggml_backend_is_metal(backend_dst)) {
508 return false;
509 }
510
511 if (!ggml_backend_buffer_is_metal(src->buffer) || !ggml_backend_buffer_is_metal(dst->buffer)) {
512 return false;
513 }
514
515 ggml_metal_t ctx_src = (ggml_metal_t)backend_src->context;
516 ggml_metal_t ctx_dst = (ggml_metal_t)backend_dst->context;
517
518 //ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
519 //ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
520
521 //ggml_metal_buffer_t buf_ctx_src = (ggml_metal_buffer_t)buf_src->context;
522 //ggml_metal_buffer_t buf_ctx_dst = (ggml_metal_buffer_t)buf_dst->context;
523
524 return ggml_metal_cpy_tensor_async(ctx_src, ctx_dst, src, dst);
525}
526
527static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
528 ggml_metal_t ctx = (ggml_metal_t)backend->context;
529
530 return ggml_metal_graph_compute(ctx, cgraph);
531}
532
533static void ggml_backend_metal_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
534 ggml_metal_t ctx = (ggml_metal_t)backend->context;
535 ggml_metal_event_t ev = (ggml_metal_event_t)event->context;
536
537 ggml_metal_event_record(ctx, ev);
538}
539
540static void ggml_backend_metal_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
541 ggml_metal_t ctx = (ggml_metal_t)backend->context;
542 ggml_metal_event_t ev = (ggml_metal_event_t)event->context;
543
544 ggml_metal_event_wait(ctx, ev);
545}
546
547static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
548 ggml_metal_t ctx = (ggml_metal_t)backend->context;
549
550 ggml_metal_graph_optimize(ctx, cgraph);
551}
552
553static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
554 GGML_ASSERT(ggml_backend_is_metal(backend));
555
556 ggml_metal_t ctx = (ggml_metal_t)backend->context;
557
558 ggml_metal_set_n_cb(ctx, n_cb);
559}
560
561static ggml_backend_i ggml_backend_metal_i = {
562 /* .get_name = */ ggml_backend_metal_name,
563 /* .free = */ ggml_backend_metal_free,
564 /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
565 /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
566 /* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups
567 /* .synchronize = */ ggml_backend_metal_synchronize,
568 /* .graph_plan_create = */ NULL,
569 /* .graph_plan_free = */ NULL,
570 /* .graph_plan_update = */ NULL,
571 /* .graph_plan_compute = */ NULL,
572 /* .graph_compute = */ ggml_backend_metal_graph_compute,
573 /* .event_record = */ ggml_backend_metal_event_record,
574 /* .event_wait = */ ggml_backend_metal_event_wait,
575 /* .graph_optimize = */ ggml_backend_metal_graph_optimize,
576};
577
578static ggml_guid_t ggml_backend_metal_guid(void) {
579 static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 };
580 return &guid;
581}
582
583ggml_backend_t ggml_backend_metal_init(void) {
584 ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
585 ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
586
587 ggml_metal_t ctx = ggml_metal_init(ctx_dev);
588 if (ctx == NULL) {
589 GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
590 return NULL;
591 }
592
593 ggml_backend_t backend = (ggml_backend_t) malloc(sizeof(ggml_backend));
594
595 *backend = {
596 /* .guid = */ ggml_backend_metal_guid(),
597 /* .interface = */ ggml_backend_metal_i,
598 /* .device = */ dev,
599 /* .context = */ ctx,
600 };
601
602 ggml_backend_metal_set_n_cb(backend, 1);
603
604 return backend;
605}
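// usage sketch (assumption, not part of this file): pairing init with the
// generic ggml_backend_free(), which dispatches to ggml_backend_metal_free()
// above
//
//   ggml_backend_t backend = ggml_backend_metal_init();
//   if (backend != NULL) {
//       // ... build a graph and run it with ggml_backend_graph_compute(backend, gf)
//       ggml_backend_free(backend);
//   }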
606
607bool ggml_backend_is_metal(ggml_backend_t backend) {
608 return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
609}
610
611void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
612 GGML_ASSERT(ggml_backend_is_metal(backend));
613
614 ggml_metal_t ctx = (ggml_metal_t)backend->context;
615
616 ggml_metal_set_abort_callback(ctx, abort_callback, user_data);
617}
618
619bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
620 GGML_ASSERT(ggml_backend_is_metal(backend));
621
622 ggml_metal_t ctx = (ggml_metal_t)backend->context;
623
624 return ggml_metal_supports_family(ctx, family);
625}
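// example (assuming the family numbering used by ggml-metal-device.m, where
// family N maps to MTLGPUFamilyApple<N>):
//
//   ggml_backend_metal_supports_family(backend, 7); // Apple7, i.e. A14/M1-class GPUs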
626
627void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
628 GGML_ASSERT(ggml_backend_is_metal(backend));
629
630 ggml_metal_t ctx = (ggml_metal_t)backend->context;
631
632 ggml_metal_capture_next_compute(ctx);
633}
634
635// backend device
636
637static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
638 ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
639
640 const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);
641
642 return props_dev->name;
643}
644
645static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
646 ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
647
648 return ggml_metal_device_get_props(ctx_dev)->desc;
649}
650
651static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
652 ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
653
654 ggml_metal_device_get_memory(ctx_dev, free, total);
655}
656
657static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) {
658 return GGML_BACKEND_DEVICE_TYPE_GPU;
659
660 GGML_UNUSED(dev);
661}
662
663static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
664 props->name = ggml_backend_metal_device_get_name(dev);
665 props->description = ggml_backend_metal_device_get_description(dev);
666 props->type = ggml_backend_metal_device_get_type(dev);
667
668 ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
669
670 props->caps = {
671 /* .async = */ true,
672 /* .host_buffer = */ false,
673 /* .buffer_from_host_ptr = */ true,
674 /* .events = */ true,
675 };
676}
677
678static ggml_backend_t ggml_backend_metal_device_init_backend(ggml_backend_dev_t dev, const char * params) {
679 ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
680
681 ggml_metal_t ctx = ggml_metal_init(ctx_dev);
682 if (ctx == NULL) {
683 GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
684 return NULL;
685 }
686
687 ggml_backend_t backend = (ggml_backend_t) malloc(sizeof(ggml_backend));
688
689 *backend = {
690 /* .guid = */ ggml_backend_metal_guid(),
691 /* .interface = */ ggml_backend_metal_i,
692 /* .device = */ dev,
693 /* .context = */ ctx,
694 };
695
696 ggml_backend_metal_set_n_cb(backend, 1);
697
698 return backend;
699
700 GGML_UNUSED(params);
701}
702
703static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) {
704 ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
705
706 const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);
707
708 return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared(props_dev->device) : ggml_backend_metal_buffer_type_private(props_dev->device);
709}
710
711static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
712 ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
713
714 ggml_metal_buffer_t res = ggml_metal_buffer_map(ctx_dev, ptr, size, max_tensor_size);
715
716 const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev);
717
718 return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(props_dev->device), ggml_backend_metal_buffer_shared_i, res, size);
719}
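// usage sketch (assumption, not part of this file): this is the
// .buffer_from_host_ptr entry point, typically reached when mmap-ing model
// weights; the host pointer must remain valid for the lifetime of the buffer
//
//   void * ptr = ...; // page-aligned host allocation, e.g. from mmap()
//   ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, ptr, size, max_tensor_size);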
720
721static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
722 ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
723
724 return ggml_metal_device_supports_op(ctx_dev, op);
725}
726
727static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
728 return
729 buft->device == dev && (
730 buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name ||
731 buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name ||
732 buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name);
733
734 GGML_UNUSED(dev);
735}
736
737static int64_t get_op_batch_size(const ggml_tensor * op) {
738 switch (op->op) {
739 case GGML_OP_MUL_MAT:
740 return op->ne[1];
741 case GGML_OP_MUL_MAT_ID:
742 return op->ne[2];
743 default:
744 return ggml_nrows(op);
745 }
746}
747
748static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
749 ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
750
751 return (op->op == GGML_OP_MUL_MAT ||
752 op->op == GGML_OP_MUL_MAT_ID) &&
753 get_op_batch_size(op) >= ggml_metal_device_get_props(ctx_dev)->op_offload_min_batch_size;
754}
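// worked example (illustrative threshold): with op_offload_min_batch_size = 32,
// a GGML_OP_MUL_MAT with op->ne[1] == 48 is worth offloading, while a
// single-token decode with op->ne[1] == 1 is left where its data lives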
755
756static ggml_backend_event_t ggml_backend_metal_device_event_new(ggml_backend_dev_t dev) {
757 ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
758
759 ggml_metal_event_t event = ggml_metal_device_event_init(ctx_dev);
760 GGML_ASSERT(event);
761
762 ggml_backend_event_t ev = new ggml_backend_event {
763 /* .device = */ dev,
764 /* .context = */ event,
765 };
766
767 return ev;
768}
769
770static void ggml_backend_metal_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
771 ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
772
773 ggml_metal_event_t ev = (ggml_metal_event_t)event->context;
774
775 ggml_metal_device_event_free(ctx_dev, ev);
776
777 delete event;
778}
779
780static void ggml_backend_metal_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
781 ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context;
782
783 ggml_metal_event_t evt = (ggml_metal_event_t)event->context;
784
785 ggml_metal_device_event_synchronize(ctx_dev, evt);
786}
787
788static ggml_backend_device_i ggml_backend_metal_device_i = {
789 /* .get_name = */ ggml_backend_metal_device_get_name,
790 /* .get_description = */ ggml_backend_metal_device_get_description,
791 /* .get_memory = */ ggml_backend_metal_device_get_memory,
792 /* .get_type = */ ggml_backend_metal_device_get_type,
793 /* .get_props = */ ggml_backend_metal_device_get_props,
794 /* .init_backend = */ ggml_backend_metal_device_init_backend,
795 /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type,
796 /* .get_host_buffer_type = */ NULL,
797 /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped,
798 /* .supports_op = */ ggml_backend_metal_device_supports_op,
799 /* .supports_buft = */ ggml_backend_metal_device_supports_buft,
800 /* .offload_op = */ ggml_backend_metal_device_offload_op,
801 /* .event_new = */ ggml_backend_metal_device_event_new,
802 /* .event_free = */ ggml_backend_metal_device_event_free,
803 /* .event_synchronize = */ ggml_backend_metal_device_event_synchronize,
804};
805
806// backend registry
807
808struct ggml_backend_metal_reg {
809 std::vector<ggml_backend_dev_t> devices;
810};
811
812typedef struct ggml_backend_metal_reg * ggml_backend_metal_reg_t;
813
814static ggml_backend_metal_reg_t ggml_backend_metal_reg_init(void) {
815 ggml_backend_metal_reg_t ctx = new struct ggml_backend_metal_reg;
816
817 return ctx;
818}
819
820static void ggml_backend_metal_reg_free(ggml_backend_metal_reg_t ctx) {
821 delete ctx;
822}
823
824struct ggml_backend_metal_reg_deleter {
825 void operator()(ggml_backend_metal_reg_t ctx) {
826 ggml_backend_metal_reg_free(ctx);
827 }
828};
829
830typedef std::unique_ptr<struct ggml_backend_metal_reg, ggml_backend_metal_reg_deleter> ggml_backend_metal_reg_ptr;
831
832static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {
833 return GGML_METAL_NAME;
834
835 GGML_UNUSED(reg);
836}
837
838static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {
839 ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context;
840 return ctx->devices.size();
841}
842
843static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {
844 ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context;
845 GGML_ASSERT(index < ctx->devices.size());
846 return ctx->devices[index];
847}
848
849static ggml_backend_feature g_ggml_backend_metal_features[] = {
850#if defined(GGML_METAL_EMBED_LIBRARY)
851 { "EMBED_LIBRARY", "1" },
852#endif
853 { NULL, NULL },
854};
855
856static ggml_backend_feature * ggml_backend_metal_get_features(ggml_backend_reg_t reg) {
857 return g_ggml_backend_metal_features;
858
859 GGML_UNUSED(reg);
860}
861
862static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) {
863 if (strcmp(name, "ggml_backend_get_features") == 0) {
864 return (void *)ggml_backend_metal_get_features;
865 }
866
867 return NULL;
868
869 GGML_UNUSED(reg);
870}
871
872static ggml_backend_reg_i ggml_backend_metal_reg_i = {
873 /* .get_name = */ ggml_backend_metal_reg_get_name,
874 /* .get_device_count = */ ggml_backend_metal_reg_device_count,
875 /* .get_device = */ ggml_backend_metal_reg_device_get,
876 /* .get_proc_address = */ ggml_backend_metal_get_proc_address,
877};
878
879static ggml_backend_dev_t ggml_backend_metal_device_init(ggml_backend_reg_t reg, int device) {
880 return new ggml_backend_device {
881 /* .iface = */ ggml_backend_metal_device_i,
882 /* .reg = */ reg,
883 /* .context = */ ggml_metal_device_get(device),
884 };
885}
886
887static void ggml_backend_metal_device_free(ggml_backend_dev_t dev) {
888 delete dev;
889}
890
891struct ggml_backend_device_deleter {
892 void operator()(ggml_backend_dev_t ctx) {
893 ggml_backend_metal_device_free(ctx);
894 }
895};
896
897typedef std::unique_ptr<ggml_backend_device, ggml_backend_device_deleter> ggml_backend_device_ptr;
898
899ggml_backend_reg_t ggml_backend_metal_reg(void) {
900 static ggml_backend_reg reg;
901 static bool initialized = false;
902
903 {
904 static std::mutex mutex;
905 std::lock_guard<std::mutex> lock(mutex);
906
907 const char * env = getenv("GGML_METAL_DEVICES");
908 if (env) {
909 g_devices = atoi(env);
910 }
911
912 static std::vector<ggml_backend_device_ptr> devs;
913
914 if (!initialized) {
915 static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init());
916
917 for (int i = 0; i < g_devices; ++i) {
918 auto * dev = ggml_backend_metal_device_init(&reg, i);
919 devs.emplace_back(dev);
920
921 reg_ctx->devices.push_back(dev);
922 }
923
924 reg = {
925 /* .api_version = */ GGML_BACKEND_API_VERSION,
926 /* .iface = */ ggml_backend_metal_reg_i,
927 /* .context = */ reg_ctx.get(),
928 };
929 }
930
931 initialized = true;
932 }
933
934 return &reg;
935}
936
937GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg)
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal b/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal
new file mode 100644
index 0000000..0036ba9
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal
@@ -0,0 +1,9798 @@
1#define GGML_COMMON_DECL_METAL
2#define GGML_COMMON_IMPL_METAL
3#if defined(GGML_METAL_EMBED_LIBRARY)
4__embed_ggml-common.h__
5#else
6#include "ggml-common.h"
7#endif
8#include "ggml-metal-impl.h"
9
10#include <metal_stdlib>
11
12#ifdef GGML_METAL_HAS_TENSOR
13#include <metal_tensor>
14
15#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
16#endif
17
18using namespace metal;
19
20#define MAX(x, y) ((x) > (y) ? (x) : (y))
21#define MIN(x, y) ((x) < (y) ? (x) : (y))
22#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; }
23
24#define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1))
25
26#define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)
27
28#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
29
30// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
31//
32// cmd:
33// .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/ggml-metal.metal
34// .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/ggml-metal.metal
35//
36#if __METAL_VERSION__ < 310 && defined(GGML_METAL_HAS_BF16)
37#undef GGML_METAL_HAS_BF16
38#endif
39
40#if defined(GGML_METAL_HAS_BF16)
41typedef matrix<bfloat, 4, 4> bfloat4x4;
42typedef matrix<bfloat, 2, 4> bfloat2x4;
43#endif
44
45constexpr constant static float kvalues_iq4nl_f[16] = {
46 -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
47};
48
49constexpr constant static float kvalues_mxfp4_f[16] = {
50 0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
51};
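// note: the mxfp4 table is antisymmetric - entries 8..15 are the negations of
// entries 0..7, so bit 3 of a 4-bit code effectively acts as the sign bit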
52
53static inline int best_index_int8(int n, constant float * val, float x) {
54 if (x <= val[0]) return 0;
55 if (x >= val[n-1]) return n-1;
56 int ml = 0, mu = n-1;
57 while (mu-ml > 1) {
58 int mav = (ml+mu)/2;
59 if (x < val[mav]) mu = mav; else ml = mav;
60 }
61 return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
62}
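// worked example: best_index_int8(16, kvalues_iq4nl_f, -90.f) bisects down to
// the neighbors (-104, -83) and returns index 2, since -90 is 7 away from -83
// but 14 away from -104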
63
64static inline float e8m0_to_fp32(uint8_t x) {
65 uint32_t bits;
66
67 if (x == 0) {
68 bits = 0x00400000;
69 } else {
70 bits = (uint32_t) x << 23;
71 }
72
73 return as_type<float>(bits);
74}
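// worked examples: x = 127 yields bits = 127 << 23 (biased exponent 127, zero
// mantissa), i.e. 1.0f; x = 0 is special-cased to the subnormal 0x00400000 =
// 2^-127, which (0 << 23) == 0.0f would not represent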
75
76static inline float dot(float x, float y) {
77 return x*y;
78}
79
80// NOTE: this is not dequantizing - we are simply fitting the template
81template <typename type4x4>
82void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
83 reg = (type4x4)(*src);
84}
85
86template <typename type4>
87void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
88 reg = (type4)(*src);
89}
90
91template <typename type4x4>
92void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
93 reg = (type4x4)(*src);
94}
95
96template <typename type4>
97void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
98 reg = (type4)(*(src));
99}
100
101#if defined(GGML_METAL_HAS_BF16)
102template <typename type4x4>
103void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
104 reg = (type4x4)(*src);
105}
106
107template <typename type4>
108void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) {
109 reg = (type4)(*(src));
110}
111#endif
112
113template <typename type4x4>
114void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
115 device const uint16_t * qs = ((device const uint16_t *)xb + 1);
116 const float d1 = il ? (xb->d / 16.h) : xb->d;
117 const float d2 = d1 / 256.f;
118 const float md = -8.h * xb->d;
119 const ushort mask0 = il ? 0x00F0 : 0x000F;
120 const ushort mask1 = mask0 << 8;
121
122 float4x4 reg_f;
123
124 for (int i = 0; i < 8; i++) {
125 reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
126 reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
127 }
128
129 reg = (type4x4) reg_f;
130}
131
132template <typename type4>
133void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
134 device const uint16_t * qs = ((device const uint16_t *)xb + 1);
135 const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
136 const float d2 = d1 / 256.f;
137 const float md = -8.h * xb->d;
138 const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
139 const ushort mask1 = mask0 << 8;
140
141 for (int i = 0; i < 2; i++) {
142 reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
143 reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
144 }
145}
146
147void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
148#pragma METAL fp math_mode(safe)
149 float amax = 0.0f; // absolute max
150 float max = 0.0f;
151
152 for (int j = 0; j < QK4_0; j++) {
153 const float v = src[j];
154 if (amax < fabs(v)) {
155 amax = fabs(v);
156 max = v;
157 }
158 }
159
160 const float d = max / -8;
161 const float id = d ? 1.0f/d : 0.0f;
162
163 dst.d = d;
164
165 for (int j = 0; j < QK4_0/2; ++j) {
166 const float x0 = src[0 + j]*id;
167 const float x1 = src[QK4_0/2 + j]*id;
168
169 const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
170 const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
171
172 dst.qs[j] = xi0;
173 dst.qs[j] |= xi1 << 4;
174 }
175}
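// round-trip sketch: d = max/-8 maps the inputs into [-8, 8]; adding 8.5f and
// truncating yields nibbles 0..16, clamped to 15, and dequantize_q4_0 above
// recovers x ~= d * (nibble - 8) via its md = -8*d offset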
176
177void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
178#pragma METAL fp math_mode(safe)
179 float min = FLT_MAX;
180 float max = -FLT_MAX;
181
182 for (int j = 0; j < QK4_1; j++) {
183 const float v = src[j];
184 if (min > v) min = v;
185 if (max < v) max = v;
186 }
187
188 const float d = (max - min) / ((1 << 4) - 1);
189 const float id = d ? 1.0f/d : 0.0f;
190
191 dst.d = d;
192 dst.m = min;
193
194 for (int j = 0; j < QK4_1/2; ++j) {
195 const float x0 = (src[0 + j] - min)*id;
196 const float x1 = (src[QK4_1/2 + j] - min)*id;
197
198 const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
199 const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
200
201 dst.qs[j] = xi0;
202 dst.qs[j] |= xi1 << 4;
203 }
204}
205
206void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
207#pragma METAL fp math_mode(safe)
208 float amax = 0.0f; // absolute max
209 float max = 0.0f;
210
211 for (int j = 0; j < QK5_0; j++) {
212 const float v = src[j];
213 if (amax < fabs(v)) {
214 amax = fabs(v);
215 max = v;
216 }
217 }
218
219 const float d = max / -16;
220 const float id = d ? 1.0f/d : 0.0f;
221
222 dst.d = d;
223
224 uint32_t qh = 0;
225 for (int j = 0; j < QK5_0/2; ++j) {
226 const float x0 = src[0 + j]*id;
227 const float x1 = src[QK5_0/2 + j]*id;
228
229 const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
230 const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
231
232 dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
233 qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
234 qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
235 }
236
237 thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
238
239 for (int j = 0; j < 4; ++j) {
240 dst.qh[j] = qh8[j];
241 }
242}
243
244void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
245#pragma METAL fp math_mode(safe)
246 float max = src[0];
247 float min = src[0];
248
249 for (int j = 1; j < QK5_1; j++) {
250 const float v = src[j];
251 min = v < min ? v : min;
252 max = v > max ? v : max;
253 }
254
255 const float d = (max - min) / 31;
256 const float id = d ? 1.0f/d : 0.0f;
257
258 dst.d = d;
259 dst.m = min;
260
261 uint32_t qh = 0;
262 for (int j = 0; j < QK5_1/2; ++j) {
263 const float x0 = (src[0 + j] - min)*id;
264 const float x1 = (src[QK5_1/2 + j] - min)*id;
265
266 const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
267 const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
268
269 dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
270 qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
271 qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
272 }
273
274 thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
275
276 for (int j = 0; j < 4; ++j) {
277 dst.qh[j] = qh8[j];
278 }
279}
280
281void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
282#pragma METAL fp math_mode(safe)
283 float amax = 0.0f; // absolute max
284
285 for (int j = 0; j < QK8_0; j++) {
286 const float v = src[j];
287 amax = MAX(amax, fabs(v));
288 }
289
290 const float d = amax / ((1 << 7) - 1);
291 const float id = d ? 1.0f/d : 0.0f;
292
293 dst.d = d;
294
295 for (int j = 0; j < QK8_0; ++j) {
296 const float x0 = src[j]*id;
297
298 dst.qs[j] = round(x0);
299 }
300}
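// round-trip sketch: d = amax/127 maps the inputs into [-127, 127]; each qs[j]
// stores the rounded quotient and dequantize_q8_0 below recovers x ~= d * qs[j]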
301
302void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
303#pragma METAL fp math_mode(safe)
304 float amax = 0.0f; // absolute max
305 float max = 0.0f;
306
307 for (int j = 0; j < QK4_NL; j++) {
308 const float v = src[j];
309 if (amax < fabs(v)) {
310 amax = fabs(v);
311 max = v;
312 }
313 }
314
315 const float d = max / kvalues_iq4nl_f[0];
316 const float id = d ? 1.0f/d : 0.0f;
317
318 float sumqx = 0, sumq2 = 0;
319 for (int j = 0; j < QK4_NL/2; ++j) {
320 const float x0 = src[0 + j]*id;
321 const float x1 = src[QK4_NL/2 + j]*id;
322
323 const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
324 const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
325
326 dst.qs[j] = xi0 | (xi1 << 4);
327
328 const float v0 = kvalues_iq4nl_f[xi0];
329 const float v1 = kvalues_iq4nl_f[xi1];
330 const float w0 = src[0 + j]*src[0 + j];
331 const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
332 sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
333 sumq2 += w0*v0*v0 + w1*v1*v1;
334
335 }
336
337 dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
338}
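// note on the final scale: with per-value weights w = x^2, minimizing
// sum w*(x - d*v)^2 over d gives d = sum(w*v*x) / sum(w*v*v), which is exactly
// sumqx/sumq2 above - a weighted least-squares refinement of the initial
// guess d = max / kvalues_iq4nl_f[0]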
339
340template <typename type4x4>
341void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
342 device const uint16_t * qs = ((device const uint16_t *)xb + 2);
343 const float d1 = il ? (xb->d / 16.h) : xb->d;
344 const float d2 = d1 / 256.f;
345 const float m = xb->m;
346 const ushort mask0 = il ? 0x00F0 : 0x000F;
347 const ushort mask1 = mask0 << 8;
348
349 float4x4 reg_f;
350
351 for (int i = 0; i < 8; i++) {
352 reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
353 reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
354 }
355
356 reg = (type4x4) reg_f;
357}
358
359template <typename type4>
360void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
361 device const uint16_t * qs = ((device const uint16_t *)xb + 2);
362 const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
363 const float d2 = d1 / 256.f;
364 const float m = xb->m;
365 const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
366 const ushort mask1 = mask0 << 8;
367
368 for (int i = 0; i < 2; i++) {
369 reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
370 reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
371 }
372}
373
374template <typename type4x4>
375void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
376 device const uint16_t * qs = ((device const uint16_t *)xb + 3);
377 const float d = xb->d;
378 const float md = -16.h * xb->d;
379 const ushort mask = il ? 0x00F0 : 0x000F;
380
381 const uint32_t qh = *((device const uint32_t *)xb->qh);
382
383 const int x_mv = il ? 4 : 0;
384
385 const int gh_mv = il ? 12 : 0;
386 const int gh_bk = il ? 0 : 4;
387
388 float4x4 reg_f;
389
390 for (int i = 0; i < 8; i++) {
391        // extract the 5th bits for x0 and x1
392 const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
393 const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
394
395        // combine the 4 bits from qs with the 5th bit
396 const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
397 const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
398
399 reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
400 reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
401 }
402
403 reg = (type4x4) reg_f;
404}
405
406template <typename type4>
407void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
408 device const uint16_t * qs = ((device const uint16_t *)xb + 3);
409 const float d = xb->d;
410 const float md = -16.h * xb->d;
411 const ushort mask = (il/4) ? 0x00F0 : 0x000F;
412
413 const uint32_t qh = *((device const uint32_t *)xb->qh);
414
415 const int x_mv = (il/4) ? 4 : 0;
416
417 const int gh_mv = (il/4) ? 12 : 0;
418 const int gh_bk = (il/4) ? 0 : 4;
419
420 for (int ii = 0; ii < 2; ii++) {
421 int i = 2*(il%4) + ii;
422
423        // extract the 5th bits for x0 and x1
424 const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
425 const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
426
427        // combine the 4 bits from qs with the 5th bit
428 const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
429 const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
430
431 reg[2*ii + 0] = d * x0 + md;
432 reg[2*ii + 1] = d * x1 + md;
433 }
434}
435
436template <typename type4x4>
437void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
438 device const uint16_t * qs = ((device const uint16_t *)xb + 4);
439 const float d = xb->d;
440 const float m = xb->m;
441 const ushort mask = il ? 0x00F0 : 0x000F;
442
443 const uint32_t qh = *((device const uint32_t *)xb->qh);
444
445 const int x_mv = il ? 4 : 0;
446
447 const int gh_mv = il ? 12 : 0;
448 const int gh_bk = il ? 0 : 4;
449
450 float4x4 reg_f;
451
452 for (int i = 0; i < 8; i++) {
453        // extract the 5th bits for x0 and x1
454 const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
455 const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
456
457        // combine the 4 bits from qs with the 5th bit
458 const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
459 const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
460
461 reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
462 reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
463 }
464
465 reg = (type4x4) reg_f;
466}
467
468template <typename type4>
469void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
470 device const uint16_t * qs = ((device const uint16_t *)xb + 4);
471 const float d = xb->d;
472 const float m = xb->m;
473 const ushort mask = (il/4) ? 0x00F0 : 0x000F;
474
475 const uint32_t qh = *((device const uint32_t *)xb->qh);
476
477 const int x_mv = (il/4) ? 4 : 0;
478
479 const int gh_mv = (il/4) ? 12 : 0;
480 const int gh_bk = (il/4) ? 0 : 4;
481
482 for (int ii = 0; ii < 2; ii++) {
483 int i = 2*(il%4) + ii;
484
485        // extract the 5th bits for x0 and x1
486 const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
487 const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
488
489        // combine the 4 bits from qs with the 5th bit
490 const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
491 const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
492
493 reg[2*ii + 0] = d * x0 + m;
494 reg[2*ii + 1] = d * x1 + m;
495 }
496}
497
498template <typename type4x4>
499void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
500 device const int8_t * qs = ((device const int8_t *)xb->qs);
501 const float d = xb->d;
502
503 float4x4 reg_f;
504
505 for (int i = 0; i < 16; i++) {
506 reg_f[i/4][i%4] = (qs[i + 16*il] * d);
507 }
508
509 reg = (type4x4) reg_f;
510}
511
512template <typename type4>
513void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
514 device const int8_t * qs = ((device const int8_t *)xb->qs);
515 const float d = xb->d;
516
517 for (int i = 0; i < 4; i++) {
518 reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
519 }
520}
521
522template <typename type4x4>
523void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
524 device const uint8_t * q2 = (device const uint8_t *)xb->qs;
525
526 const float d = e8m0_to_fp32(xb->e);
527 const uint8_t shr = il >= 1 ? 4 : 0;
528
529 for (int i = 0; i < 4; ++i) {
530 reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
531 reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
532 reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
533 reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
534 }
535}
536
537template <typename type4>
538void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
539 device const uint8_t * q2 = (device const uint8_t *)xb->qs;
540
541 const float d = e8m0_to_fp32(xb->e);
542 const short il4 = il%4;
543
544 const uint8_t shr = il >= 4 ? 4 : 0;
545
546 reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
547 reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
548 reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
549 reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
550}
551
552template <typename type4x4>
553void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
554 const float d = xb->d;
555 const float min = xb->dmin;
556 device const uint8_t * q = (device const uint8_t *)xb->qs;
557 float dl, ml;
558 uint8_t sc = xb->scales[il];
559
560 q = q + 32*(il/8) + 16*(il&1);
561 il = (il/2)%4;
562
563 half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
564 uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
565 dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
566 for (int i = 0; i < 16; ++i) {
567 reg[i/4][i%4] = dl * (q[i] & mask) - ml;
568 }
569}
570
571template <typename type4x4>
572void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
573 const half d_all = xb->d;
574 device const uint8_t * q = (device const uint8_t *)xb->qs;
575 device const uint8_t * h = (device const uint8_t *)xb->hmask;
576 device const int8_t * scales = (device const int8_t *)xb->scales;
577
578 q = q + 32 * (il/8) + 16 * (il&1);
579 h = h + 16 * (il&1);
580 uint8_t m = 1 << (il/2);
581 uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
582 ((il/4)>0 ? 12 : 3);
583 uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
584 uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
585 int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
586 : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
587 float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
588 const float ml = 4.f * dl;
589
590 il = (il/2) & 3;
591 const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
592 const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
593 dl *= coef;
594
595 for (int i = 0; i < 16; ++i) {
596 reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
597 }
598}
599
600static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
601 return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
602 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
603}
604
605template <typename type4x4>
606void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
607 device const uchar * q = xb->qs;
608
609 short is = (il/4) * 2;
610 q = q + (il/4) * 32 + 16 * (il&1);
611 il = il & 3;
612 const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
613 const float d = il < 2 ? xb->d : xb->d / 16.h;
614 const float min = xb->dmin;
615 const float dl = d * sc[0];
616 const float ml = min * sc[1];
617
618 const ushort mask = il < 2 ? 0x0F : 0xF0;
619 for (int i = 0; i < 16; ++i) {
620 reg[i/4][i%4] = dl * (q[i] & mask) - ml;
621 }
622}
623
624template <typename type4x4>
625void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
626 device const uint8_t * q = xb->qs;
627 device const uint8_t * qh = xb->qh;
628
629 short is = (il/4) * 2;
630 q = q + 32 * (il/4) + 16 * (il&1);
631 qh = qh + 16 * (il&1);
632 uint8_t ul = 1 << (il/2);
633 il = il & 3;
634 const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
635 const float d = il < 2 ? xb->d : xb->d / 16.f;
636 const float min = xb->dmin;
637 const float dl = d * sc[0];
638 const float ml = min * sc[1];
639
640 const ushort mask = il<2 ? 0x0F : 0xF0;
641 const float qh_val = il<2 ? 16.f : 256.f;
642 for (int i = 0; i < 16; ++i) {
643 reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
644 }
645}
646
647template <typename type4x4>
648void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
649 const half d_all = xb->d;
650 device const uint16_t * ql = (device const uint16_t *)xb->ql;
651 device const uint16_t * qh = (device const uint16_t *)xb->qh;
652 device const int8_t * scales = (device const int8_t *)xb->scales;
653
654 ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
655 qh = qh + 16*(il/8) + 8*(il&1);
656 float sc = scales[(il%2) + 2 * ((il/2))];
657 il = (il/2) & 3;
658
659 const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
660 const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F;
661 const float ml = d_all * sc * 32.f;
662 const float dl0 = d_all * sc;
663 const float dl1 = dl0 / 256.f;
664 const float dl2 = dl0 / (256.f * 256.f);
665 const float dl3 = dl0 / (256.f * 256.f * 256.f);
666 const uint8_t shr_h = il>2 ? 2 : 0;
667 const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
668 const uint8_t shr_l = il>1 ? 4 : 0;
669 for (int i = 0; i < 4; ++i) {
670 const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
671 const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
672 const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
673 reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml;
674 reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml;
675 reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml;
676 reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
677 }
678}
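// note: dl1..dl3 fold the would-be >>8, >>16 and >>24 shifts into the scale,
// so each byte of q is masked in place and multiplied without being extracted
// first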
679
680template <typename type4x4>
681void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
682 // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
683 const float d = xb->d;
684 const int ib32 = il/2;
685 il = il%2;
686 // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
687 // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
688 device const uint16_t * q2 = xb->qs + 4*ib32;
689 const uint32_t aux32_g = q2[0] | (q2[1] << 16);
690 const uint32_t aux32_s = q2[2] | (q2[3] << 16);
691 thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
692 const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
693 constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
694 uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
695 for (int i = 0; i < 8; ++i) {
696 reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
697 }
698 grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
699 signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
700 for (int i = 0; i < 8; ++i) {
701 reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
702 }
703}
704
705template <typename type4x4>
706void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
707 // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
708 const float d = xb->d;
709 const int ib32 = il/2;
710 il = il%2;
711 // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
712 device const uint16_t * q2 = xb->qs + 4*ib32;
713 const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
714 constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
715 uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
716 for (int i = 0; i < 8; ++i) {
717 reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
718 }
719 grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
720 signs = ksigns_iq2xs[q2[2*il+1] >> 9];
721 for (int i = 0; i < 8; ++i) {
722 reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
723 }
724}
725
726template <typename type4x4>
727void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
728 // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
729 const float d = xb->d;
730 const int ib32 = il/2;
731 il = il%2;
732 // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
733 device const uint8_t * q3 = xb->qs + 8*ib32;
734 device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
735 const uint32_t aux32 = gas[0] | (gas[1] << 16);
736 const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
737 constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
738 constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
739 uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
740 for (int i = 0; i < 4; ++i) {
741 reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
742 reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
743 }
744 grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
745 grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
746 signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
747 for (int i = 0; i < 4; ++i) {
748 reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
749 reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
750 }
751}
752
753template <typename type4x4>
754void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
755 // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
756 const float d = xb->d;
757 const int ib32 = il/2;
758 il = il%2;
759 // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
760 device const uint8_t * qs = xb->qs + 8*ib32;
761 device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
762 const uint8_t qh = xb->qh[ib32] >> 4*il;
763 const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
764 constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
765 constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
766 for (int i = 0; i < 4; ++i) {
767 reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
768 reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
769 }
770 grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
771 grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
772 for (int i = 0; i < 4; ++i) {
773 reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
774 reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
775 }
776}
777
778template <typename type4x4>
779void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
780 // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
781 const float d = xb->d;
782 const int ib32 = il/2;
783 il = il%2;
784 // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
785 device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
786 device const uint8_t * signs = qs + QK_K/8;
787 const uint8_t qh = xb->qh[ib32] >> 4*il;
788 const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
789 constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
790 constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
791 for (int i = 0; i < 8; ++i) {
792 reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
793 reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
794 }
795}
796
797template <typename type4x4>
798void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
799 // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
800 const int ib32 = il/2;
801 il = il%2;
802 const float d = xb->d;
803 device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
804 device const uint16_t * qh = xb->qh;
805 const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
806 const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
807 const uint16_t h = qh[ib32] >> 6*il;
808 constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
809 constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
810 for (int i = 0; i < 4; ++i) {
811 reg[0][i] = dl * (grid1[i] & 0xf) + ml;
812 reg[1][i] = dl * (grid1[i] >> 4) + ml;
813 reg[2][i] = dl * (grid2[i] & 0xf) + ml;
814 reg[3][i] = dl * (grid2[i] >> 4) + ml;
815 }
816}
817
818template <typename type4x4>
819void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
820 // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
821 const int ib32 = il/2;
822 il = il%2;
823 device const uint16_t * sc = (device const uint16_t *)xb->scales;
824
825 iq1m_scale_t scale;
826 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
827 const float d = scale.f16;
828
829 device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
830 device const uint8_t * qh = xb->qh + 2*ib32 + il;
831
832 const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
833 const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
834 const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
835 constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
836 constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
837 for (int i = 0; i < 4; ++i) {
838 reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
839 reg[1][i] = dl * (grid1[i] >> 4) + ml1;
840 reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
841 reg[3][i] = dl * (grid2[i] >> 4) + ml2;
842 }
843}
844
845template <typename type4x4>
846void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
847 device const uint16_t * q4 = (device const uint16_t *)xb->qs;
848 const float d = xb->d;
849 uint32_t aux32;
850 thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
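    // note: each iteration below unpacks four 4-bit indices at once: two uint16
    // loads are packed into a uint32, shifted by 0 (il = 0 -> low nibbles) or 4
    // (il = 1 -> high nibbles) and masked with 0x0f0f0f0f; q8 then aliases the
    // four result bytes as indices into the kvalues_iq4nl_f lookup table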
851 for (int i = 0; i < 4; ++i) {
852 aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
853 reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
854 reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
855 reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
856 reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
857 }
858}
859
860template <typename type4>
861void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
862 device const uint16_t * q4 = (device const uint16_t *)xb->qs;
863 const float d = xb->d;
864 uint32_t aux32;
865 thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
866 aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
867 reg[0] = d * kvalues_iq4nl_f[q8[0]];
868 reg[1] = d * kvalues_iq4nl_f[q8[1]];
869 reg[2] = d * kvalues_iq4nl_f[q8[2]];
870 reg[3] = d * kvalues_iq4nl_f[q8[3]];
871}
872
873template <typename type4x4>
874void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
875 // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
876 const int ib32 = il/2;
877 il = il%2;
878 // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
879 device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
880 const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
881 const float d = (float)xb->d * (ls - 32);
882 uint32_t aux32;
883 thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
884 for (int i = 0; i < 4; ++i) {
885 aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
886 reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
887 reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
888 reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
889 reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
890 }
891}
892
893enum ggml_sort_order {
894 GGML_SORT_ORDER_ASC,
895 GGML_SORT_ORDER_DESC,
896};
897
898constant float GELU_COEF_A = 0.044715f;
899constant float GELU_QUICK_COEF = -1.702f;
900constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
901constant float SQRT_2_INV = 0.70710678118654752440084436210484f;
902
903 // based on Abramowitz and Stegun formula 7.1.26, a Hastings-type approximation
904// ref: https://www.johndcook.com/blog/python_erf/
905constant float p_erf = 0.3275911f;
906constant float a1_erf = 0.254829592f;
907constant float a2_erf = -0.284496736f;
908constant float a3_erf = 1.421413741f;
909constant float a4_erf = -1.453152027f;
910constant float a5_erf = 1.061405429f;
911
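// For reference, the constants above plug into Abramowitz & Stegun 7.1.26
// (evaluated below in Horner form), which for x >= 0 gives
//
//   erf(x) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5)*exp(-x*x),  t = 1/(1 + p*x)
//
// with absolute error <= 1.5e-7; negative inputs use erf(-x) = -erf(x).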
912template<typename T>
913inline T erf_approx(T x) {
914 T sign_x = sign(x);
915 x = fabs(x);
916 T t = 1.0f / (1.0f + p_erf * x);
917 T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x);
918 return sign_x * y;
919}
920
921template<typename T> T elu_approx(T x);
922
923template<> inline float elu_approx<float>(float x) {
924 return (x > 0.f) ? x : (exp(x) - 1);
925}
926
927template<> inline float4 elu_approx<float4>(float4 x) {
928 float4 res;
929
930 res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f);
931 res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f);
932 res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f);
933 res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f);
934
935 return res;
936}
937
938constant short FC_unary_op [[function_constant(FC_UNARY + 0)]];
939constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]];
940
941template <typename T0, typename T, typename TC>
942kernel void kernel_unary_impl(
943 constant ggml_metal_kargs_unary & args,
944 device const char * src0,
945 device char * dst,
946 uint3 tgpig[[threadgroup_position_in_grid]],
947 ushort3 tpitg[[thread_position_in_threadgroup]],
948 ushort3 ntg[[threads_per_threadgroup]]) {
949#define FC_OP FC_unary_op
950#define FC_CNT FC_unary_cnt
951
952 device const T0 * src0_ptr;
953 device T * dst_ptr;
954
955 int i0;
956
957 if (FC_CNT) {
958 i0 = tgpig.x;
959
960 src0_ptr = (device const T0 *) (src0);
961 dst_ptr = (device T *) (dst);
962 } else {
963 const int i03 = tgpig.z;
964 const int i02 = tgpig.y;
965 const int k0 = tgpig.x/args.ne01;
966 const int i01 = tgpig.x - k0*args.ne01;
967
968 i0 = k0*ntg.x + tpitg.x;
969
970 src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
971 dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 );
972 }
973
974 {
975 //threadgroup_barrier(mem_flags::mem_none);
976
977 if (!FC_CNT) {
978 if (i0 >= args.ne0) {
979 return;
980 }
981 }
982
983 const TC x = (TC) src0_ptr[i0];
984
985 if (FC_OP == OP_UNARY_NUM_SCALE) {
986 dst_ptr[i0] = (T) (args.scale * x + args.bias);
987 }
988
989 if (FC_OP == OP_UNARY_NUM_FILL) {
990 dst_ptr[i0] = (T) args.val;
991 }
992
993 if (FC_OP == OP_UNARY_NUM_CLAMP) {
994 dst_ptr[i0] = (T) clamp(x, args.min, args.max);
995 }
996
997 if (FC_OP == OP_UNARY_NUM_SQR) {
998 dst_ptr[i0] = (T) (x * x);
999 }
1000
1001 if (FC_OP == OP_UNARY_NUM_SQRT) {
1002 dst_ptr[i0] = (T) sqrt(x);
1003 }
1004
1005 if (FC_OP == OP_UNARY_NUM_SIN) {
1006 dst_ptr[i0] = (T) sin(x);
1007 }
1008
1009 if (FC_OP == OP_UNARY_NUM_COS) {
1010 dst_ptr[i0] = (T) cos(x);
1011 }
1012
1013 if (FC_OP == OP_UNARY_NUM_LOG) {
1014 dst_ptr[i0] = (T) log(x);
1015 }
1016
1017 if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) {
1018 dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope));
1019 }
1020
1021 if (FC_OP == OP_UNARY_NUM_TANH) {
1022 dst_ptr[i0] = (T) precise::tanh(x);
1023 }
1024
1025 if (FC_OP == OP_UNARY_NUM_RELU) {
1026 dst_ptr[i0] = (T) fmax(0, x);
1027 }
1028
1029 if (FC_OP == OP_UNARY_NUM_SIGMOID) {
1030 dst_ptr[i0] = (T) (1 / (1 + exp(-x)));
1031 }
1032
1033 if (FC_OP == OP_UNARY_NUM_GELU) {
1034 dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x))));
1035 }
1036
1037 if (FC_OP == OP_UNARY_NUM_GELU_ERF) {
1038 dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x)));
1039 }
1040
1041 if (FC_OP == OP_UNARY_NUM_GELU_QUICK) {
1042 dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x))));
1043 }
1044
1045 if (FC_OP == OP_UNARY_NUM_SILU) {
1046 dst_ptr[i0] = (T) (x / (1 + exp(-x)));
1047 }
1048
1049 if (FC_OP == OP_UNARY_NUM_ELU) {
1050 dst_ptr[i0] = (T) elu_approx(x);
1051 }
1052
1053 if (FC_OP == OP_UNARY_NUM_NEG) {
1054 dst_ptr[i0] = (T) -x;
1055 }
1056
1057 if (FC_OP == OP_UNARY_NUM_ABS) {
1058 dst_ptr[i0] = (T) fabs(x);
1059 }
1060
1061 if (FC_OP == OP_UNARY_NUM_SGN) {
1062 dst_ptr[i0] = T(x > 0) - T(x < 0);
1063 }
1064
1065 if (FC_OP == OP_UNARY_NUM_STEP) {
1066 dst_ptr[i0] = T(x > 0);
1067 }
1068
1069 if (FC_OP == OP_UNARY_NUM_HARDSWISH) {
1070 dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5)));
1071 }
1072
1073 if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) {
1074 dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5));
1075 }
1076
1077 if (FC_OP == OP_UNARY_NUM_EXP) {
1078 dst_ptr[i0] = (T) exp(x);
1079 }
1080
1081 if (FC_OP == OP_UNARY_NUM_SOFTPLUS) {
1082 dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20);
1083 }
1084
1085 if (FC_OP == OP_UNARY_NUM_EXPM1) {
1086 // TODO: precise implementation
1087 dst_ptr[i0] = (T) (exp(x) - 1);
1088 }
1089 }
1090
1091#undef FC_OP
1092#undef FC_CNT
1093}
1094
1095typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t;
1096
1097template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>;
1098template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>;
1099template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>;
1100template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>;
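// A minimal host-side sketch (Objective-C) of how one of these variants could be
// specialized at pipeline-creation time; `lib` (an id<MTLLibrary>) is an assumption,
// not this backend's actual dispatch code:
//
//   MTLFunctionConstantValues * cv = [[MTLFunctionConstantValues alloc] init];
//
//   short op  = OP_UNARY_NUM_RELU;
//   bool  cnt = false;
//
//   [cv setConstantValue:&op  type:MTLDataTypeShort atIndex:FC_UNARY + 0];
//   [cv setConstantValue:&cnt type:MTLDataTypeBool  atIndex:FC_UNARY + 1];
//
//   NSError * err = nil;
//   id<MTLFunction> fn = [lib newFunctionWithName:@"kernel_unary_f32_f32" constantValues:cv error:&err];
//
// The FC_unary_op comparisons in the kernel body are then folded at compile time,
// so each specialized pipeline keeps only the branch for its operation.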
1101
1102// OP: 0 - add, 1 - sub, 2 - mul, 3 - div
1103constant short FC_bin_op [[function_constant(FC_BIN + 0)]];
1104constant short FC_bin_f [[function_constant(FC_BIN + 1)]];
1105constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]];
1106
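// A worked example of the fusion path, assuming FC_OP == 0 (add) and FC_F == 3:
// the kernel reads one src0 row and accumulates three src1 rows found at the
// byte offsets args.o1[0..2], i.e.
//
//   dst[i0] = src0[i0] + src1_a[i10] + src1_b[i10] + src1_c[i10],  i10 = i0 % ne10
//
// so a chain of three add nodes collapses into a single dispatch (src1_a/b/c are
// illustrative names, not identifiers from this file).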
1107template <typename T0, typename T1, typename T>
1108kernel void kernel_bin_fuse_impl(
1109 constant ggml_metal_kargs_bin & args,
1110 device const char * src0,
1111 device const char * src1,
1112 device char * dst,
1113 uint3 tgpig[[threadgroup_position_in_grid]],
1114 ushort3 tpitg[[thread_position_in_threadgroup]],
1115 ushort3 ntg[[threads_per_threadgroup]]) {
1116#define FC_OP FC_bin_op
1117#define FC_F FC_bin_f
1118#define FC_RB FC_bin_rb
1119
1120 if (FC_RB) {
1121 // row broadcast
1122 const uint i0 = tgpig.x;
1123 const uint i1 = i0%args.ne10;
1124
1125 device const T0 * src0_row = (device const T0 *) (src0);
1126 device T * dst_row = (device T *) (dst);
1127
1128 if (FC_F == 1) {
1129 device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]);
1130
1131 if (FC_OP == 0) {
1132 dst_row[i0] = src0_row[i0] + src1_row[i1];
1133 }
1134
1135 if (FC_OP == 1) {
1136 dst_row[i0] = src0_row[i0] - src1_row[i1];
1137 }
1138
1139 if (FC_OP == 2) {
1140 dst_row[i0] = src0_row[i0] * src1_row[i1];
1141 }
1142
1143 if (FC_OP == 3) {
1144 dst_row[i0] = src0_row[i0] / src1_row[i1];
1145 }
1146 } else {
1147 T0 res = src0_row[i0];
1148
1149 if (FC_OP == 0) {
1150 FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1151 res += ((device const T1 *) (src1 + args.o1[j]))[i1];
1152 }
1153 }
1154
1155 if (FC_OP == 1) {
1156 FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1157 res -= ((device const T1 *) (src1 + args.o1[j]))[i1];
1158 }
1159 }
1160
1161 if (FC_OP == 2) {
1162 FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1163 res *= ((device const T1 *) (src1 + args.o1[j]))[i1];
1164 }
1165 }
1166
1167 if (FC_OP == 3) {
1168 FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1169 res /= ((device const T1 *) (src1 + args.o1[j]))[i1];
1170 }
1171 }
1172
1173 dst_row[i0] = res;
1174 }
1175 } else {
1176 const int i03 = tgpig.z;
1177 const int i02 = tgpig.y;
1178 const int i01 = tgpig.x;
1179
1180 if (i01 >= args.ne01) {
1181 return;
1182 }
1183
1184 const int i13 = i03%args.ne13;
1185 const int i12 = i02%args.ne12;
1186 const int i11 = i01%args.ne11;
1187
1188 device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs);
1189 device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs);
1190
1191 if (FC_F == 1) {
1192 device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
1193
1194 for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1195 const int i10 = i0%args.ne10;
1196
1197 if (FC_OP == 0) {
1198 dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10];
1199 }
1200
1201 if (FC_OP == 1) {
1202 dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10];
1203 }
1204
1205 if (FC_OP == 2) {
1206 dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10];
1207 }
1208
1209 if (FC_OP == 3) {
1210 dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10];
1211 }
1212 }
1213 } else {
1214 device const T1 * src1_ptr[8];
1215 FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1216 src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11);
1217 }
1218
1219 for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1220 const int i10 = i0%args.ne10;
1221
1222 T res = src0_ptr[i0];
1223
1224 if (FC_OP == 0) {
1225 FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1226 res += src1_ptr[j][i10];
1227 }
1228 }
1229
1230 if (FC_OP == 1) {
1231 FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1232 res -= src1_ptr[j][i10];
1233 }
1234 }
1235
1236 if (FC_OP == 2) {
1237 FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1238 res *= src1_ptr[j][i10];
1239 }
1240 }
1241
1242 if (FC_OP == 3) {
1243 FOR_UNROLL (short j = 0; j < FC_F; ++j) {
1244 res /= src1_ptr[j][i10];
1245 }
1246 }
1247
1248 dst_ptr[i0] = res;
1249 }
1250 }
1251 }
1252
1253#undef FC_OP
1254#undef FC_F
1255#undef FC_RB
1256}
1257
1258typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t;
1259
1260template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>;
1261template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>;
1262
1263kernel void kernel_add_id(
1264 constant ggml_metal_kargs_add_id & args,
1265 device const char * src0,
1266 device const char * src1,
1267 device const char * src2,
1268 device char * dst,
1269 uint3 tgpig[[threadgroup_position_in_grid]],
1270 ushort3 tpitg[[thread_position_in_threadgroup]],
1271 ushort3 ntg[[threads_per_threadgroup]]) {
1272 const int i1 = tgpig.x;
1273 const int i2 = tgpig.y;
1274
1275 const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
1276
1277 const size_t nb1 = args.ne0 * sizeof(float);
1278 const size_t nb2 = args.ne1 * nb1;
1279
1280 device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
1281 device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
1282 device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
1283
1284 for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1285 dst_row[i0] = src0_row[i0] + src1_row[i0];
1286 }
1287}
1288
1289template<typename T>
1290kernel void kernel_repeat(
1291 constant ggml_metal_kargs_repeat & args,
1292 device const char * src0,
1293 device char * dst,
1294 uint3 tgpig[[threadgroup_position_in_grid]],
1295 ushort3 tpitg[[thread_position_in_threadgroup]],
1296 ushort3 ntg[[threads_per_threadgroup]]) {
1297 const int i3 = tgpig.z;
1298 const int i2 = tgpig.y;
1299 const int i1 = tgpig.x;
1300
1301 const int i03 = i3%args.ne03;
1302 const int i02 = i2%args.ne02;
1303 const int i01 = i1%args.ne01;
1304
1305 device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
1306 device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
1307
1308 for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
1309 const int i00 = i0%args.ne00;
1310 *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
1311 }
1312}
1313
1314typedef decltype(kernel_repeat<float>) kernel_repeat_t;
1315
1316template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
1317template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
1318template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
1319template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
1320
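// The GLU kernels below all compute dst = f(x0) * x1 for a gate activation f
// (ReLU for reglu, GELU for geglu, SiLU for swiglu, ...). args.i00 and args.i10
// are element offsets into the source rows, which lets the host point src0 and
// src1 at the two halves of a single fused gate/up tensor.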
1321kernel void kernel_reglu_f32(
1322 constant ggml_metal_kargs_glu & args,
1323 device const char * src0,
1324 device const char * src1,
1325 device char * dst,
1326 uint tgpig[[threadgroup_position_in_grid]],
1327 uint tpitg[[thread_position_in_threadgroup]],
1328 uint ntg[[threads_per_threadgroup]]) {
1329 device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1330 device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1331 device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1332
1333 for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1334 const float x0 = src0_row[i0];
1335 const float x1 = src1_row[i0];
1336
1337 dst_row[i0] = x0*x1*(x0 > 0.0f);
1338 }
1339}
1340
1341kernel void kernel_geglu_f32(
1342 constant ggml_metal_kargs_glu & args,
1343 device const char * src0,
1344 device const char * src1,
1345 device char * dst,
1346 uint tgpig[[threadgroup_position_in_grid]],
1347 uint tpitg[[thread_position_in_threadgroup]],
1348 uint ntg[[threads_per_threadgroup]]) {
1349 device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1350 device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1351 device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1352
1353 for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1354 const float x0 = src0_row[i0];
1355 const float x1 = src1_row[i0];
1356
1357 const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
1358
1359 dst_row[i0] = gelu*x1;
1360 }
1361}
1362
1363kernel void kernel_swiglu_f32(
1364 constant ggml_metal_kargs_glu & args,
1365 device const char * src0,
1366 device const char * src1,
1367 device char * dst,
1368 uint tgpig[[threadgroup_position_in_grid]],
1369 uint tpitg[[thread_position_in_threadgroup]],
1370 uint ntg[[threads_per_threadgroup]]) {
1371 device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1372 device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1373 device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1374
1375 for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1376 const float x0 = src0_row[i0];
1377 const float x1 = src1_row[i0];
1378
1379 const float silu = x0 / (1.0f + exp(-x0));
1380
1381 dst_row[i0] = silu*x1;
1382 }
1383}
1384
1385kernel void kernel_swiglu_oai_f32(
1386 constant ggml_metal_kargs_glu & args,
1387 device const char * src0,
1388 device const char * src1,
1389 device char * dst,
1390 uint tgpig[[threadgroup_position_in_grid]],
1391 uint tpitg[[thread_position_in_threadgroup]],
1392 uint ntg[[threads_per_threadgroup]]) {
1393 device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1394 device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1395 device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1396
1397 for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1398 float x0 = src0_row[i0];
1399 float x1 = src1_row[i0];
1400
1401 x0 = min(x0, args.limit);
1402 x1 = max(min(x1, args.limit), -args.limit);
1403
1404 float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
1405 out_glu = out_glu * (1.0f + x1);
1406
1407 dst_row[i0] = out_glu;
1408 }
1409}
1410
1411kernel void kernel_geglu_erf_f32(
1412 constant ggml_metal_kargs_glu & args,
1413 device const char * src0,
1414 device const char * src1,
1415 device char * dst,
1416 uint tgpig[[threadgroup_position_in_grid]],
1417 uint tpitg[[thread_position_in_threadgroup]],
1418 uint ntg[[threads_per_threadgroup]]) {
1419 device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1420 device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1421 device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1422
1423 for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1424 const float x0 = src0_row[i0];
1425 const float x1 = src1_row[i0];
1426
1427 const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV));
1428
1429 dst_row[i0] = gelu_erf*x1;
1430 }
1431}
1432
1433kernel void kernel_geglu_quick_f32(
1434 constant ggml_metal_kargs_glu & args,
1435 device const char * src0,
1436 device const char * src1,
1437 device char * dst,
1438 uint tgpig[[threadgroup_position_in_grid]],
1439 uint tpitg[[thread_position_in_threadgroup]],
1440 uint ntg[[threads_per_threadgroup]]) {
1441 device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
1442 device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
1443 device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
1444
1445 for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
1446 const float x0 = src0_row[i0];
1447 const float x1 = src1_row[i0];
1448
1449 const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0)));
1450
1451 dst_row[i0] = gelu_quick*x1;
1452 }
1453}
1454
1455kernel void kernel_op_sum_f32(
1456 constant ggml_metal_kargs_sum & args,
1457 device const float * src0,
1458 device float * dst,
1459 threadgroup float * shmem_f32 [[threadgroup(0)]],
1460 uint3 tgpig[[threadgroup_position_in_grid]],
1461 ushort3 tpitg[[thread_position_in_threadgroup]],
1462 ushort sgitg[[simdgroup_index_in_threadgroup]],
1463 ushort tiisg[[thread_index_in_simdgroup]],
1464 ushort3 ntg[[threads_per_threadgroup]]) {
1465
1466 if (args.np == 0) {
1467 return;
1468 }
1469
1470    // TODO: make this a function constant
1471 const uint nsg = (ntg.x + 31) / 32;
1472
1473 float sumf = 0;
1474
1475 for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
1476 sumf += src0[i0];
1477 }
1478
1479 sumf = simd_sum(sumf);
1480
1481 if (tiisg == 0) {
1482 shmem_f32[sgitg] = sumf;
1483 }
1484
1485 threadgroup_barrier(mem_flags::mem_threadgroup);
1486
1487 float total = 0;
1488
1489 if (sgitg == 0) {
1490 float v = 0;
1491
1492 if (tpitg.x < nsg) {
1493 v = shmem_f32[tpitg.x];
1494 }
1495
1496 total = simd_sum(v);
1497
1498 if (tpitg.x == 0) {
1499 dst[0] = total;
1500 }
1501 }
1502}
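// The reduction above follows the usual two-level pattern: simd_sum reduces
// within each simdgroup, lane 0 of every simdgroup writes its partial to
// threadgroup memory, and simdgroup 0 reduces those partials. For example,
// ntg.x = 256 with a SIMD width of 32 yields nsg = 8 partials, so the final
// simd_sum has non-zero contributions only from the first 8 lanes.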
1503
1504template <bool norm>
1505kernel void kernel_sum_rows(
1506 constant ggml_metal_kargs_sum_rows & args,
1507 device const float * src0,
1508 device float * dst,
1509 threadgroup float * shmem_f32 [[threadgroup(0)]],
1510 uint3 tgpig[[threadgroup_position_in_grid]],
1511 ushort3 tpitg[[thread_position_in_threadgroup]],
1512 ushort sgitg[[simdgroup_index_in_threadgroup]],
1513 ushort tiisg[[thread_index_in_simdgroup]],
1514 ushort3 ntg[[threads_per_threadgroup]]) {
1515 int64_t i3 = tgpig.z;
1516 int64_t i2 = tgpig.y;
1517 int64_t i1 = tgpig.x;
1518
1519 if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
1520 return;
1521 }
1522
1523 if (sgitg == 0) {
1524 shmem_f32[tiisg] = 0.0f;
1525 }
1526
1527 device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
1528 device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
1529
1530 float sumf = 0;
1531
1532 for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
1533 sumf += src_row[i0];
1534 }
1535
1536 sumf = simd_sum(sumf);
1537
1538 threadgroup_barrier(mem_flags::mem_threadgroup);
1539
1540 if (tiisg == 0) {
1541 shmem_f32[sgitg] = sumf;
1542 }
1543
1544 threadgroup_barrier(mem_flags::mem_threadgroup);
1545
1546 sumf = shmem_f32[tiisg];
1547 sumf = simd_sum(sumf);
1548
1549 if (tpitg.x == 0) {
1550 dst_row[0] = norm ? sumf / args.ne00 : sumf;
1551 }
1552}
1553
1554typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
1555
1556template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
1557template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>;
1558
1559template<typename T>
1560kernel void kernel_cumsum_blk(
1561 constant ggml_metal_kargs_cumsum_blk & args,
1562 device const char * src0,
1563 device char * tmp,
1564 device char * dst,
1565 threadgroup char * shmem [[threadgroup(0)]],
1566 uint3 tgpig[[threadgroup_position_in_grid]],
1567 ushort3 tpitg[[thread_position_in_threadgroup]],
1568 ushort sgitg[[simdgroup_index_in_threadgroup]],
1569 ushort tiisg[[thread_index_in_simdgroup]],
1570 ushort3 ntg[[threads_per_threadgroup]]) {
1571 const int ib = tgpig[0]/args.ne01;
1572
1573 const int i00 = ib*ntg.x;
1574 const int i01 = tgpig[0]%args.ne01;
1575 const int i02 = tgpig[1];
1576 const int i03 = tgpig[2];
1577
1578 device const float * src0_row = (device const float *) (src0 +
1579 args.nb01*i01 +
1580 args.nb02*i02 +
1581 args.nb03*i03);
1582
1583 threadgroup float * shmem_f32 = (threadgroup float *) shmem;
1584
1585 float v = 0.0f;
1586
1587 if (i00 + tpitg.x < args.ne00) {
1588 v = src0_row[i00 + tpitg.x];
1589 }
1590
1591 float s = simd_prefix_inclusive_sum(v);
1592
1593 if (tiisg == N_SIMDWIDTH - 1) {
1594 shmem_f32[sgitg] = s;
1595 }
1596
1597 threadgroup_barrier(mem_flags::mem_threadgroup);
1598
1599 if (sgitg == 0) {
1600 shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
1601 }
1602
1603 threadgroup_barrier(mem_flags::mem_threadgroup);
1604
1605 s += shmem_f32[sgitg];
1606
1607 device float * dst_row = (device float *) dst +
1608 args.ne00*i01 +
1609 args.ne00*args.ne01*i02 +
1610 args.ne00*args.ne01*args.ne02*i03;
1611
1612 if (i00 + tpitg.x < args.ne00) {
1613 dst_row[i00 + tpitg.x] = s;
1614 }
1615
1616 if (args.outb && tpitg.x == ntg.x - 1) {
1617 device float * tmp_row = (device float *) tmp +
1618 args.net0*i01 +
1619 args.net0*args.net1*i02 +
1620 args.net0*args.net1*args.net2*i03;
1621
1622 tmp_row[ib] = s;
1623 }
1624}
1625
1626typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;
1627
1628template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
1629
1630template<typename T>
1631kernel void kernel_cumsum_add(
1632 constant ggml_metal_kargs_cumsum_add & args,
1633 device const char * tmp,
1634 device char * dst,
1635 uint3 tgpig[[threadgroup_position_in_grid]],
1636 ushort3 tpitg[[thread_position_in_threadgroup]],
1637 ushort sgitg[[simdgroup_index_in_threadgroup]],
1638 ushort tiisg[[thread_index_in_simdgroup]],
1639 ushort3 ntg[[threads_per_threadgroup]]) {
1640 const int ib = tgpig[0]/args.ne01;
1641
1642 if (ib == 0) {
1643 return;
1644 }
1645
1646 const int i00 = ib*ntg.x;
1647 const int i01 = tgpig[0]%args.ne01;
1648 const int i02 = tgpig[1];
1649 const int i03 = tgpig[2];
1650
1651 device const float * tmp_row = (device const float *) (tmp +
1652 args.nbt1*i01 +
1653 args.nbt2*i02 +
1654 args.nbt3*i03);
1655
1656 device float * dst_row = (device float *) dst +
1657 args.ne00*i01 +
1658 args.ne00*args.ne01*i02 +
1659 args.ne00*args.ne01*args.ne02*i03;
1660
1661 if (i00 + tpitg.x < args.ne00) {
1662 dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
1663 }
1664}
1665
1666typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
1667
1668template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
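// Together these two kernels form a two-pass block scan. A small sketch with a
// threadgroup size of 4 and the row [1, 2, 3, 4, 5, 6, 7, 8]:
//
//   kernel_cumsum_blk: per-block inclusive scans [1, 3, 6, 10 | 5, 11, 18, 26],
//                      block totals tmp = [10, 26]
//   kernel_cumsum_add: every block ib > 0 adds tmp[ib - 1], giving the full
//                      cumulative sum [1, 3, 6, 10, 15, 21, 28, 36]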
1669
1670
1671template<uint32_t ttype>
1672bool _ggml_vec_tri_cmp(const int i, const int r);
1673
1674template<>
1675bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) {
1676 return i < r;
1677}
1678
1679template<>
1680bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) {
1681 return i <= r;
1682}
1683
1684template<>
1685bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) {
1686 return i > r;
1687}
1688
1689template<>
1690bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) {
1691 return i >= r;
1692}
1693
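// For a 3x3 tile with column index i and row index r, the comparisons above
// produce the following keep-masks (1 = copy src, 0 = write zero):
//
//   UPPER_DIAG (0): i >= r -> [1 1 1; 0 1 1; 0 0 1]
//   UPPER      (1): i >  r -> [0 1 1; 0 0 1; 0 0 0]
//   LOWER_DIAG (2): i <= r -> [1 0 0; 1 1 0; 1 1 1]
//   LOWER      (3): i <  r -> [0 0 0; 1 0 0; 1 1 0]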
1694template<typename T, int ttype>
1695kernel void kernel_tri(
1696 constant ggml_metal_kargs_tri & args,
1697 device const char * src0,
1698        device char * dst,
1699 uint3 tgpig[[threadgroup_position_in_grid]],
1700 ushort3 tpitg[[thread_position_in_threadgroup]],
1701 ushort3 ntg[[threads_per_threadgroup]]) {
1702 const int i3 = tgpig.z;
1703 const int i2 = tgpig.y;
1704 const int i1 = tgpig.x;
1705
1706 if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
1707 return;
1708 }
1709
1710 device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
1711 device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3);
1712
1713    // Each thread handles a single element of the row when ne00 fits within the
1714    // threadgroup size; otherwise this loops once for each index the thread is
1715    // responsible for
1716 for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
1717        // Use the comparison result as a 0/1 multiplicative mask to stay branchless
1718 dst_row[i0] = static_cast<T>(_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0];
1719 }
1720}
1721
1722typedef decltype(kernel_tri<float, 0>) kernel_tri_t;
1723
1724template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>;
1725template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>;
1726template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>;
1727template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>;
1728template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>;
1729template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>;
1730template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>;
1731template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>;
1732#if defined(GGML_METAL_HAS_BF16)
1733template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>;
1734template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>;
1735template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>;
1736template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>;
1737#endif
1738
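// kernel_soft_max computes a numerically stable softmax over each row:
//
//   y_i = exp(s_i*scale + slope*m_i - M) / (sink + sum_j exp(s_j*scale + slope*m_j - M))
//
// where M is the row maximum (including the sink term when present), m is the
// optional mask, slope is the per-head ALiBi slope base^exp, and
// sink = exp(src2[i02] - M) when the optional src2 values are provided.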
1739template<typename T>
1740kernel void kernel_soft_max(
1741 constant ggml_metal_kargs_soft_max & args,
1742 device const char * src0,
1743 device const char * src1,
1744 device const char * src2,
1745 device char * dst,
1746 threadgroup float * buf [[threadgroup(0)]],
1747 uint3 tgpig[[threadgroup_position_in_grid]],
1748 uint3 tpitg[[thread_position_in_threadgroup]],
1749 uint sgitg[[simdgroup_index_in_threadgroup]],
1750 uint tiisg[[thread_index_in_simdgroup]],
1751 uint3 tptg[[threads_per_threadgroup]]) {
1752 const int32_t i03 = tgpig.z;
1753 const int32_t i02 = tgpig.y;
1754 const int32_t i01 = tgpig.x;
1755
1756 const int32_t i13 = i03%args.ne13;
1757 const int32_t i12 = i02%args.ne12;
1758 const int32_t i11 = i01;
1759
1760 device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
1761 device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
1762 device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
1763 device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
1764
1765 float slope = 1.0f;
1766
1767 // ALiBi
1768 if (args.max_bias > 0.0f) {
1769 const int32_t h = i02;
1770
1771 const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1772 const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
1773
1774 slope = pow(base, exp);
1775 }
1776
1777 // parallel max
1778 float lmax = psrc2 ? psrc2[i02] : -INFINITY;
1779
1780 for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1781 lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
1782 }
1783
1784 // find the max value in the block
1785 float max_val = simd_max(lmax);
1786 if (tptg.x > N_SIMDWIDTH) {
1787 if (sgitg == 0) {
1788 buf[tiisg] = -INFINITY;
1789 }
1790
1791 threadgroup_barrier(mem_flags::mem_threadgroup);
1792
1793 if (tiisg == 0) {
1794 buf[sgitg] = max_val;
1795 }
1796
1797 threadgroup_barrier(mem_flags::mem_threadgroup);
1798
1799 max_val = buf[tiisg];
1800 max_val = simd_max(max_val);
1801 }
1802
1803 // parallel sum
1804 float lsum = 0.0f;
1805 for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1806 const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
1807 lsum += exp_psrc0;
1808 pdst[i00] = exp_psrc0;
1809 }
1810
1811 // This barrier fixes a failing test
1812 // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
1813 threadgroup_barrier(mem_flags::mem_none);
1814
1815 float sum = simd_sum(lsum);
1816
1817 if (tptg.x > N_SIMDWIDTH) {
1818 if (sgitg == 0) {
1819 buf[tiisg] = 0.0f;
1820 }
1821
1822 threadgroup_barrier(mem_flags::mem_threadgroup);
1823
1824 if (tiisg == 0) {
1825 buf[sgitg] = sum;
1826 }
1827
1828 threadgroup_barrier(mem_flags::mem_threadgroup);
1829
1830 sum = buf[tiisg];
1831 sum = simd_sum(sum);
1832 }
1833
1834 if (psrc2) {
1835 sum += exp(psrc2[i02] - max_val);
1836 }
1837
1838 const float inv_sum = 1.0f/sum;
1839
1840 for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
1841 pdst[i00] *= inv_sum;
1842 }
1843}
1844
1845template<typename T>
1846kernel void kernel_soft_max_4(
1847 constant ggml_metal_kargs_soft_max & args,
1848 device const char * src0,
1849 device const char * src1,
1850 device const char * src2,
1851 device char * dst,
1852 threadgroup float * buf [[threadgroup(0)]],
1853 uint3 tgpig[[threadgroup_position_in_grid]],
1854 uint3 tpitg[[thread_position_in_threadgroup]],
1855 uint sgitg[[simdgroup_index_in_threadgroup]],
1856 uint tiisg[[thread_index_in_simdgroup]],
1857 uint3 tptg[[threads_per_threadgroup]]) {
1858 const int32_t i03 = tgpig.z;
1859 const int32_t i02 = tgpig.y;
1860 const int32_t i01 = tgpig.x;
1861
1862 const int32_t i13 = i03%args.ne13;
1863 const int32_t i12 = i02%args.ne12;
1864 const int32_t i11 = i01;
1865
1866 device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
1867 device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
1868 device const float * psrc2 = src2 != src0 ? (device const float * ) (src2) : nullptr;
1869 device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
1870
1871 float slope = 1.0f;
1872
1873 if (args.max_bias > 0.0f) {
1874 const int32_t h = i02;
1875
1876 const float base = h < args.n_head_log2 ? args.m0 : args.m1;
1877 const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
1878
1879 slope = pow(base, exp);
1880 }
1881
1882 // parallel max
1883 float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
1884
1885 for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1886 lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
1887 }
1888
1889 const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
1890
1891 float max_val = simd_max(lmax);
1892 if (tptg.x > N_SIMDWIDTH) {
1893 if (sgitg == 0) {
1894 buf[tiisg] = -INFINITY;
1895 }
1896
1897 threadgroup_barrier(mem_flags::mem_threadgroup);
1898
1899 if (tiisg == 0) {
1900 buf[sgitg] = max_val;
1901 }
1902
1903 threadgroup_barrier(mem_flags::mem_threadgroup);
1904
1905 max_val = buf[tiisg];
1906 max_val = simd_max(max_val);
1907 }
1908
1909 // parallel sum
1910 float4 lsum4 = 0.0f;
1911 for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1912 const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
1913 lsum4 += exp_psrc4;
1914 pdst4[i00] = exp_psrc4;
1915 }
1916
1917 const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
1918
1919 // This barrier fixes a failing test
1920 // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335
1921 threadgroup_barrier(mem_flags::mem_none);
1922
1923 float sum = simd_sum(lsum);
1924
1925 if (tptg.x > N_SIMDWIDTH) {
1926 if (sgitg == 0) {
1927 buf[tiisg] = 0.0f;
1928 }
1929
1930 threadgroup_barrier(mem_flags::mem_threadgroup);
1931
1932 if (tiisg == 0) {
1933 buf[sgitg] = sum;
1934 }
1935
1936 threadgroup_barrier(mem_flags::mem_threadgroup);
1937
1938 sum = buf[tiisg];
1939 sum = simd_sum(sum);
1940 }
1941
1942 if (psrc2) {
1943 sum += exp(psrc2[i02] - max_val);
1944 }
1945
1946 const float inv_sum = 1.0f/sum;
1947
1948 for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
1949 pdst4[i00] *= inv_sum;
1950 }
1951}
1952
1953typedef decltype(kernel_soft_max<float>) kernel_soft_max_t;
1954typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t;
1955
1956template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>;
1957template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>;
1958template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>;
1959template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>;
1960
1961// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
1962kernel void kernel_ssm_conv_f32_f32(
1963 constant ggml_metal_kargs_ssm_conv & args,
1964 device const void * src0,
1965 device const void * src1,
1966 device float * dst,
1967 uint3 tgpig[[threadgroup_position_in_grid]],
1968 uint3 tpitg[[thread_position_in_threadgroup]],
1969 uint3 ntg[[threads_per_threadgroup]]) {
1970 const int64_t ir = tgpig.x;
1971 const int64_t i2 = tgpig.y;
1972 const int64_t i3 = tgpig.z;
1973
1974 const int64_t nc = args.ne10;
1975 //const int64_t ncs = args.ne00;
1976 //const int64_t nr = args.ne01;
1977 //const int64_t n_t = args.ne1;
1978 //const int64_t n_s = args.ne2;
1979
1980 device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
1981 device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
1982 device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
1983
1984 float sumf = 0.0f;
1985
1986 for (int64_t i0 = 0; i0 < nc; ++i0) {
1987 sumf += s[i0] * c[i0];
1988 }
1989
1990 x[0] = sumf;
1991}
1992
1993kernel void kernel_ssm_conv_f32_f32_4(
1994 constant ggml_metal_kargs_ssm_conv & args,
1995 device const void * src0,
1996 device const void * src1,
1997 device float * dst,
1998 uint3 tgpig[[threadgroup_position_in_grid]],
1999 uint3 tpitg[[thread_position_in_threadgroup]],
2000 uint3 ntg[[threads_per_threadgroup]]) {
2001 const int64_t ir = tgpig.x;
2002 const int64_t i2 = tgpig.y;
2003 const int64_t i3 = tgpig.z;
2004
2005 const int64_t nc = args.ne10;
2006 //const int64_t ncs = args.ne00;
2007 //const int64_t nr = args.ne01;
2008 //const int64_t n_t = args.ne1;
2009 //const int64_t n_s = args.ne2;
2010
2011 device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
2012 device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
2013 device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
2014
2015 float sumf = 0.0f;
2016
2017 for (int64_t i0 = 0; i0 < nc/4; ++i0) {
2018 sumf += dot(s[i0], c[i0]);
2019 }
2020
2021 x[0] = sumf;
2022}
2023
2024constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]];
2025
2026// Batched version: each threadgroup processes multiple tokens for better efficiency
2027// Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens
2028kernel void kernel_ssm_conv_f32_f32_batched(
2029 constant ggml_metal_kargs_ssm_conv & args,
2030 device const void * src0,
2031 device const void * src1,
2032 device float * dst,
2033 uint3 tgpig[[threadgroup_position_in_grid]],
2034 uint3 tpitg[[thread_position_in_threadgroup]],
2035 uint3 ntg[[threads_per_threadgroup]]) {
2036 // tgpig.x = row index (ir)
2037 // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
2038 // tgpig.z = sequence index (i3)
2039 // tpitg.x = thread within batch (0..BATCH_SIZE-1)
2040 const short BATCH_SIZE = FC_ssm_conv_bs;
2041
2042 const int64_t ir = tgpig.x;
2043 const int64_t i2_base = tgpig.y * BATCH_SIZE;
2044 const int64_t i3 = tgpig.z;
2045 const int64_t i2_off = tpitg.x;
2046 const int64_t i2 = i2_base + i2_off;
2047
2048 const int64_t nc = args.ne10; // conv kernel size (typically 4)
2049 const int64_t n_t = args.ne1; // number of tokens
2050
2051 // Bounds check for partial batches at the end
2052 if (i2 >= n_t) {
2053 return;
2054 }
2055
2056 // Load conv weights (shared across all tokens for this row)
2057 device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
2058
2059 // Load source for this specific token
2060 device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
2061
2062 // Output location for this token
2063 device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
2064
2065 float sumf = 0.0f;
2066 for (int64_t i0 = 0; i0 < nc; ++i0) {
2067 sumf += s[i0] * c[i0];
2068 }
2069
2070 x[0] = sumf;
2071}
2072
2073kernel void kernel_ssm_conv_f32_f32_batched_4(
2074 constant ggml_metal_kargs_ssm_conv & args,
2075 device const void * src0,
2076 device const void * src1,
2077 device float * dst,
2078 uint3 tgpig[[threadgroup_position_in_grid]],
2079 uint3 tpitg[[thread_position_in_threadgroup]],
2080 uint3 ntg[[threads_per_threadgroup]]) {
2081 // tgpig.x = row index (ir)
2082 // tgpig.y = batch of tokens (i2_base / BATCH_SIZE)
2083 // tgpig.z = sequence index (i3)
2084 // tpitg.x = thread within batch (0..BATCH_SIZE-1)
2085 const short BATCH_SIZE = FC_ssm_conv_bs;
2086
2087 const int64_t ir = tgpig.x;
2088 const int64_t i2_base = tgpig.y * BATCH_SIZE;
2089 const int64_t i3 = tgpig.z;
2090 const int64_t i2_off = tpitg.x;
2091 const int64_t i2 = i2_base + i2_off;
2092
2093 const int64_t nc = args.ne10; // conv kernel size (typically 4)
2094 const int64_t n_t = args.ne1; // number of tokens
2095
2096 // Bounds check for partial batches at the end
2097 if (i2 >= n_t) {
2098 return;
2099 }
2100
2101 // Load conv weights (shared across all tokens for this row)
2102 device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
2103
2104 // Load source for this specific token
2105 device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
2106
2107 // Output location for this token
2108 device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);
2109
2110 float sumf = 0.0f;
2111 for (int64_t i0 = 0; i0 < nc/4; ++i0) {
2112 sumf += dot(s[i0], c[i0]);
2113 }
2114
2115 x[0] = sumf;
2116}
2117
2118// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
2119 // Optimized version: reduces redundant memory loads by letting the first sgptg threads stage shared per-token values
2120kernel void kernel_ssm_scan_f32(
2121 constant ggml_metal_kargs_ssm_scan & args,
2122 device const void * src0,
2123 device const void * src1,
2124 device const void * src2,
2125 device const void * src3,
2126 device const void * src4,
2127 device const void * src5,
2128 device const void * src6,
2129 device float * dst,
2130 threadgroup float * shared [[threadgroup(0)]],
2131 uint3 tgpig[[threadgroup_position_in_grid]],
2132 ushort3 tpitg[[thread_position_in_threadgroup]],
2133 ushort sgitg[[simdgroup_index_in_threadgroup]],
2134 ushort tiisg[[thread_index_in_simdgroup]],
2135 ushort sgptg[[simdgroups_per_threadgroup]],
2136 uint3 tgpg[[threadgroups_per_grid]]) {
2137 constexpr short NW = N_SIMDWIDTH;
2138
2139 // Shared memory layout:
2140 // [0..sgptg*NW-1]: partial sums for reduction (existing)
2141 // [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch
2142 // [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch
2143 threadgroup float * shared_sums = shared;
2144 threadgroup float * shared_x_dt = shared + sgptg * NW;
2145 threadgroup float * shared_dA = shared + sgptg * NW + sgptg;
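    // total threadgroup memory implied by this layout: (sgptg*NW + 2*sgptg) floats;
    // e.g. ntg.x = 128 with NW = 32 gives sgptg = 4, i.e. 136 floats (544 bytes)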
2146
2147 shared_sums[tpitg.x] = 0.0f;
2148
2149 const int32_t i0 = tpitg.x;
2150 const int32_t i1 = tgpig.x;
2151 const int32_t ir = tgpig.y; // current head
2152 const int32_t i3 = tgpig.z; // current seq
2153
2154 const int32_t nc = args.d_state;
2155 const int32_t nr = args.d_inner;
2156 const int32_t nh = args.n_head;
2157 const int32_t ng = args.n_group;
2158 const int32_t n_t = args.n_seq_tokens;
2159
2160 const int32_t s_off = args.s_off;
2161
2162 device const int32_t * ids = (device const int32_t *) src6;
2163
2164 device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
2165 device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
2166
2167 const int32_t i = i0 + i1*nc;
2168 const int32_t g = ir / (nh / ng); // repeat_interleave
2169
2170 float s0 = s0_buff[i];
2171 float s = 0.0f;
2172
2173 device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
2174
2175 const float A0 = A[i0%args.ne30];
2176
2177 device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
2178 device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns}
2179 device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns}
2180 device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns}
2181
2182 device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}
2183
2184 for (int i2 = 0; i2 < n_t; i2 += sgptg) {
2185 threadgroup_barrier(mem_flags::mem_threadgroup);
2186
2187 // Pre-compute x_dt and dA for this batch of tokens
2188 // Only first sgptg threads do the loads and expensive math
2189        // Only the first sgptg threads do the loads and the expensive math
2190 // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20)
2191 device const float * x_t = x + i0 * args.ns12;
2192 device const float * dt_t = dt + i0 * args.ns21;
2193
2194 const float dt0 = dt_t[0];
2195 const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
2196 shared_x_dt[i0] = x_t[0] * dtsp;
2197 shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies
2198 }
2199
2200 threadgroup_barrier(mem_flags::mem_threadgroup);
2201
2202 for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
2203 const float x_dt = shared_x_dt[t];
2204 const float dA = exp(shared_dA[t] * A0);
2205
2206 s = (s0 * dA) + (B[i0] * x_dt);
2207
2208 const float sumf = simd_sum(s * C[i0]);
2209
2210 if (tiisg == 0) {
2211 shared_sums[t*NW + sgitg] = sumf;
2212 }
2213
2214            // recurrence: carry the state to the next token
2215 s0 = s;
2216
2217 B += args.ns42;
2218 C += args.ns52;
2219 }
2220
2221 // Advance pointers for next batch
2222 x += sgptg * args.ns12;
2223 dt += sgptg * args.ns21;
2224
2225 threadgroup_barrier(mem_flags::mem_threadgroup);
2226
2227 const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]);
2228
2229 if (tiisg == 0 && i2 + sgitg < n_t) {
2230 y[sgitg*nh*nr] = sumf;
2231 }
2232
2233 y += sgptg*nh*nr;
2234 }
2235
2236 s_buff[i] = s;
2237}
2238
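// kernel_rwkv_wkv6_f32 applies the WKV6 recurrence with one thread per state
// row; per head channel j and current value v:
//
//   kv_j = k_j * v
//   y   += r_j * (tf_j * kv_j + s_j)
//   s_j  = s_j * td_j + kv_j
//
// the per-token k/r/td vectors are staged in threadgroup memory so every thread
// can read the whole head.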
2239kernel void kernel_rwkv_wkv6_f32(
2240 device const float * k,
2241 device const float * v,
2242 device const float * r,
2243 device const float * tf,
2244 device const float * td,
2245 device const float * state_in,
2246 device float * dst,
2247 constant uint & B,
2248 constant uint & T,
2249 constant uint & C,
2250 constant uint & H,
2251 uint3 tgpig[[threadgroup_position_in_grid]],
2252 uint3 tpitg[[thread_position_in_threadgroup]],
2253 uint3 ntg[[threads_per_threadgroup]]) {
2254
2255 const uint head_size = 64; // TODO: support head_size = 128
2256 const uint batch_id = tgpig.x / H;
2257 const uint head_id = tgpig.x % H;
2258 const uint tid = tpitg.x;
2259
2260 if (batch_id >= B || head_id >= H) {
2261 return;
2262 }
2263
2264 const uint state_size = C * head_size;
2265 const uint n_seq_tokens = T / B;
2266
2267 threadgroup float _k[head_size];
2268 threadgroup float _r[head_size];
2269 threadgroup float _tf[head_size];
2270 threadgroup float _td[head_size];
2271
2272 float state[head_size];
2273
2274 for (uint i = 0; i < head_size; i++) {
2275 state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
2276 + i * head_size + tid];
2277 }
2278
2279 threadgroup_barrier(mem_flags::mem_threadgroup);
2280 _tf[tid] = tf[head_id * head_size + tid];
2281 threadgroup_barrier(mem_flags::mem_threadgroup);
2282
2283 const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
2284 const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
2285
2286 for (uint t = start_t; t < end_t; t += C) {
2287 threadgroup_barrier(mem_flags::mem_threadgroup);
2288 _k[tid] = k[t];
2289 _r[tid] = r[t];
2290 _td[tid] = td[t];
2291 threadgroup_barrier(mem_flags::mem_threadgroup);
2292
2293 const float v_val = v[t];
2294 float y = 0.0;
2295
2296 for (uint j = 0; j < head_size; j += 4) {
2297 float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
2298 float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
2299 float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
2300 float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
2301 float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
2302
2303 float4 kv = k_vec * v_val;
2304
2305 float4 temp = tf_vec * kv + s_vec;
2306 y += dot(r_vec, temp);
2307
2308 s_vec = s_vec * td_vec + kv;
2309 state[j] = s_vec[0];
2310 state[j+1] = s_vec[1];
2311 state[j+2] = s_vec[2];
2312 state[j+3] = s_vec[3];
2313 }
2314
2315 dst[t] = y;
2316 }
2317
2318 for (uint i = 0; i < head_size; i++) {
2319 dst[T * C + batch_id * state_size + head_id * head_size * head_size
2320 + i * head_size + tid] = state[i];
2321 }
2322}
2323
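// kernel_rwkv_wkv7_f32 applies the WKV7 recurrence in the same layout; per head
// channel j and current value v:
//
//   sa   = sum_j a_j * s_j
//   s_j  = s_j * w_j + k_j * v + sa * b_j
//   y   += r_j * s_j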
2324kernel void kernel_rwkv_wkv7_f32(
2325 device const float * r,
2326 device const float * w,
2327 device const float * k,
2328 device const float * v,
2329 device const float * a,
2330 device const float * b,
2331 device const float * state_in,
2332 device float * dst,
2333 constant uint & B,
2334 constant uint & T,
2335 constant uint & C,
2336 constant uint & H,
2337 uint3 tgpig[[threadgroup_position_in_grid]],
2338 uint3 tpitg[[thread_position_in_threadgroup]],
2339 uint3 ntg[[threads_per_threadgroup]]) {
2340
2341 const uint head_size = 64; // TODO: support head_size = 128
2342 const uint batch_id = tgpig.x / H;
2343 const uint head_id = tgpig.x % H;
2344 const uint tid = tpitg.x;
2345
2346 if (batch_id >= B || head_id >= H) {
2347 return;
2348 }
2349
2350 const uint state_size = C * head_size;
2351 const uint n_seq_tokens = T / B;
2352
2353 threadgroup float _r[head_size];
2354 threadgroup float _w[head_size];
2355 threadgroup float _k[head_size];
2356 threadgroup float _a[head_size];
2357 threadgroup float _b[head_size];
2358
2359 float state[head_size];
2360
2361 for (uint i = 0; i < head_size; i++) {
2362 state[i] = state_in[batch_id * state_size + head_id * head_size * head_size
2363 + tid * head_size + i];
2364 }
2365
2366 const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid;
2367 const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid;
2368
2369 for (uint t = start_t; t < end_t; t += C) {
2370 threadgroup_barrier(mem_flags::mem_threadgroup);
2371 _r[tid] = r[t];
2372 _w[tid] = w[t];
2373 _k[tid] = k[t];
2374 _a[tid] = a[t];
2375 _b[tid] = b[t];
2376 threadgroup_barrier(mem_flags::mem_threadgroup);
2377
2378 const float v_val = v[t];
2379 float y = 0.0, sa = 0.0;
2380
2381 float4 sa_vec(0.0);
2382
2383 for (uint j = 0; j < head_size; j += 4) {
2384 float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
2385 float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
2386 sa_vec += a_vec * s_vec;
2387 }
2388 sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3];
2389
2390 for (uint j = 0; j < head_size; j += 4) {
2391 float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
2392 float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
2393 float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
2394 float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
2395 float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
2396
2397 float4 kv = k_vec * v_val;
2398
2399 s_vec = s_vec * w_vec + kv + sa * b_vec;
2400 y += dot(s_vec, r_vec);
2401
2402 state[j] = s_vec[0];
2403 state[j+1] = s_vec[1];
2404 state[j+2] = s_vec[2];
2405 state[j+3] = s_vec[3];
2406 }
2407
2408 dst[t] = y;
2409 }
2410
2411 for (uint i = 0; i < head_size; i++) {
2412 dst[T * C + batch_id * state_size + head_id * head_size * head_size
2413 + tid * head_size + i] = state[i];
2414 }
2415}
2416
2417constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
2418constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]];
2419constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]];
2420
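// kernel_solve_tri_f32 performs forward substitution on a lower-triangular
// system L * X = B, one column of X per simdgroup:
//
//   x_r = (b_r - sum_{j < r} L[r][j] * x_j) / L[r][r]
//
// rows of L are staged in threadgroup memory NSG at a time, and the partial dot
// product over j is reduced across the simdgroup with simd_sum.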
2421kernel void kernel_solve_tri_f32(
2422 constant ggml_metal_kargs_solve_tri & args,
2423 device const char * src0,
2424 device const char * src1,
2425 device char * dst,
2426 threadgroup char * shmem [[threadgroup(0)]],
2427 ushort3 tgpig[[threadgroup_position_in_grid]],
2428 ushort sgitg[[simdgroup_index_in_threadgroup]],
2429 ushort tiisg[[thread_index_in_simdgroup]],
2430 ushort3 ntg[[threads_per_threadgroup]]) {
2431 constexpr short NW = N_SIMDWIDTH;
2432
2433 const short NSG = FC_solve_tri_nsg;
2434 const short N = FC_solve_tri_n;
2435 const short K = FC_solve_tri_k;
2436 const short NP = PAD2(N, NW);
2437
2438 const int32_t ne02 = args.ne02;
2439 const int32_t ne03 = args.ne03;
2440
2441 const int32_t i03 = tgpig.z;
2442 const int32_t i02 = tgpig.y;
2443 const int32_t i01 = tgpig.x*NSG + sgitg;
2444
2445 threadgroup float * sh0 = (threadgroup float *) shmem;
2446
2447 device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N;
2448 device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01;
2449 device float * dst_ptr = (device float *)(dst + i02 * args.nb2 + i03 * args.nb3) + i01;
2450
2451 for (short rr = 0; rr < N; rr += NSG) {
2452 threadgroup_barrier(mem_flags::mem_threadgroup);
2453
2454 {
2455 threadgroup float * sh0_cur = sh0 + sgitg*NP;
2456
2457 for (short t = 0; t*NW < N; ++t) {
2458 const short idx = t*NW + tiisg;
2459 sh0_cur[idx] = src0_ptr[idx];
2460 }
2461
2462 src0_ptr += NSG*N;
2463 }
2464
2465 threadgroup_barrier(mem_flags::mem_threadgroup);
2466
2467 if (i01 >= args.ne10) {
2468 continue;
2469 }
2470
2471 for (short ir = 0; ir < NSG && rr + ir < N; ++ir) {
2472 const short r = rr + ir;
2473
2474 threadgroup float * sh0_cur = sh0 + ir*NP;
2475
2476 float sum = 0.0f;
2477
2478 for (short t = 0; t*NW < r; ++t) {
2479 const short idx = t*NW + tiisg;
2480 sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r);
2481 }
2482
2483 sum = simd_sum(sum);
2484
2485 if (tiisg == 0) {
2486 const float diag = sh0_cur[r];
2487
2488 dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag;
2489 }
2490 }
2491 }
2492}
2493
2494kernel void kernel_argmax_f32(
2495 constant ggml_metal_kargs_argmax & args,
2496 device const char * src0,
2497 device char * dst,
2498 threadgroup char * shmem [[threadgroup(0)]],
2499 uint tgpig[[threadgroup_position_in_grid]],
2500 uint tpitg[[thread_position_in_threadgroup]],
2501 uint sgitg[[simdgroup_index_in_threadgroup]],
2502 uint tiisg[[thread_index_in_simdgroup]],
2503 uint ntg[[threads_per_threadgroup]]) {
2504 device const float * x_row = (device const float *) ((device const char *) src0 + tgpig * args.nb01);
2505
2506 float lmax = -INFINITY;
2507 int32_t larg = -1;
2508
2509 for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
2510 if (x_row[i00] > lmax) {
2511 lmax = x_row[i00];
2512 larg = i00;
2513 }
2514 }
2515
2516    // find the argmax value within the simdgroup
2517 float max_val = simd_max(lmax);
2518 int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
2519
2520 device int32_t * dst_i32 = (device int32_t *) dst;
2521
2522 threadgroup float * shared_maxval = (threadgroup float *) shmem;
2523 threadgroup int32_t * shared_argmax = (threadgroup int32_t *) shmem + N_SIMDWIDTH;
2524
2525 if (ntg > N_SIMDWIDTH) {
2526 if (sgitg == 0) {
2527 shared_maxval[tiisg] = -INFINITY;
2528 shared_argmax[tiisg] = -1;
2529 }
2530
2531 threadgroup_barrier(mem_flags::mem_threadgroup);
2532
2533 if (tiisg == 0) {
2534 shared_maxval[sgitg] = max_val;
2535 shared_argmax[sgitg] = arg_val;
2536 }
2537
2538 threadgroup_barrier(mem_flags::mem_threadgroup);
2539
2540 max_val = shared_maxval[tiisg];
2541 arg_val = shared_argmax[tiisg];
2542
2543 float max_val_reduced = simd_max(max_val);
2544 int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
2545
2546 dst_i32[tgpig] = arg_val_reduced;
2547
2548 return;
2549 }
2550
2551 dst_i32[tgpig] = arg_val;
2552}
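// note on the reduction trick above: after max_val = simd_max(lmax), lanes
// whose local max equals the simdgroup-wide max contribute their index and all
// other lanes contribute -1, so simd_max(select(-1, larg, lmax == max_val))
// returns the largest index attaining the max; the same two-step pattern is
// repeated across simdgroups via threadgroup memory when ntg > N_SIMDWIDTH.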
2553
2554// F == 1 : norm (no fuse)
2555// F == 2 : norm + mul
2556// F == 3 : norm + mul + add
2557template <typename T, short F>
2558kernel void kernel_norm_fuse_impl(
2559 constant ggml_metal_kargs_norm & args,
2560 device const char * src0,
2561 device const char * src1_0,
2562 device const char * src1_1,
2563 device char * dst,
2564 threadgroup float * shmem_f32 [[threadgroup(0)]],
2565 uint3 tgpig[[threadgroup_position_in_grid]],
2566 ushort3 tpitg[[thread_position_in_threadgroup]],
2567 ushort sgitg[[simdgroup_index_in_threadgroup]],
2568 ushort tiisg[[thread_index_in_simdgroup]],
2569 ushort3 ntg[[threads_per_threadgroup]]) {
2570 if (sgitg == 0) {
2571 shmem_f32[tiisg] = 0.0f;
2572 }
2573
2574 const int i01 = tgpig.x;
2575 const int i02 = tgpig.y;
2576 const int i03 = tgpig.z;
2577
2578 device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
2579
2580 device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
2581 device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
2582
2583 T sumft(0.0f);
2584
2585 float sumf = 0.0f;
2586
2587 for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
2588 sumft += x[i00];
2589 }
2590 sumf = dot(sumft, T(1.0f));
2591 sumf = simd_sum(sumf);
2592
2593 threadgroup_barrier(mem_flags::mem_threadgroup);
2594
2595 if (tiisg == 0) {
2596 shmem_f32[sgitg] = sumf;
2597 }
2598
2599 threadgroup_barrier(mem_flags::mem_threadgroup);
2600
2601 sumf = shmem_f32[tiisg];
2602 sumf = simd_sum(sumf);
2603
2604 const float mean = sumf/args.ne00;
2605
2606 device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
2607
2608 sumf = 0.0f;
2609 for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
2610 y[i00] = x[i00] - mean;
2611 sumf += dot(y[i00], y[i00]);
2612 }
2613 sumf = simd_sum(sumf);
2614
2615 threadgroup_barrier(mem_flags::mem_threadgroup);
2616
2617 if (tiisg == 0) {
2618 shmem_f32[sgitg] = sumf;
2619 }
2620
2621 threadgroup_barrier(mem_flags::mem_threadgroup);
2622
2623 sumf = shmem_f32[tiisg];
2624 sumf = simd_sum(sumf);
2625
2626 const float variance = sumf/args.ne00;
2627
2628 const float scale = 1.0f/sqrt(variance + args.eps);
2629 for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
2630 if (F == 1) {
2631 y[i00] = (y[i00]*scale);
2632 }
2633 if (F == 2) {
2634 y[i00] = (y[i00]*scale)*f0[i00];
2635 }
2636 if (F == 3) {
2637 y[i00] = (y[i00]*scale)*f0[i00] + f1[i00];
2638 }
2639 }
2640}
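// note: per row x of length ne00, the fused variants above compute
//
//   F == 1:  y = (x - mean(x)) / sqrt(var(x) + eps)
//   F == 2:  y = ((x - mean(x)) / sqrt(var(x) + eps)) * f0
//   F == 3:  y = ((x - mean(x)) / sqrt(var(x) + eps)) * f0 + f1
//
// both row statistics use the same two-stage reduction: simd_sum within each
// simdgroup, then a second simd_sum over the partials staged in shmem_f32.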
2641
2642typedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t;
2643
2644template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>;
2645template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 2>;
2646template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>;
2647
2648template [[host_name("kernel_norm_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>;
2649template [[host_name("kernel_norm_mul_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>;
2650template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>;
2651
2652// F == 1 : rms_norm (no fuse)
2653// F == 2 : rms_norm + mul
2654// F == 3 : rms_norm + mul + add
2655template <typename T, short F>
2656kernel void kernel_rms_norm_fuse_impl(
2657 constant ggml_metal_kargs_norm & args,
2658 device const char * src0,
2659 device const char * src1_0,
2660 device const char * src1_1,
2661 device char * dst,
2662 threadgroup float * shmem_f32 [[threadgroup(0)]],
2663 uint3 tgpig[[threadgroup_position_in_grid]],
2664 ushort3 tpitg[[thread_position_in_threadgroup]],
2665 ushort sgitg[[simdgroup_index_in_threadgroup]],
2666 ushort tiisg[[thread_index_in_simdgroup]],
2667 ushort3 ntg[[threads_per_threadgroup]]) {
2668 if (sgitg == 0) {
2669 shmem_f32[tiisg] = 0.0f;
2670 }
2671
2672 const int i01 = tgpig.x;
2673 const int i02 = tgpig.y;
2674 const int i03 = tgpig.z;
2675
2676 device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]);
2677
2678 device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]);
2679 device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]);
2680
2681 float sumf = 0.0f;
2682
2683 // parallel sum
2684 for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
2685 sumf += dot(x[i00], x[i00]);
2686 }
2687 sumf = simd_sum(sumf);
2688
2689 threadgroup_barrier(mem_flags::mem_threadgroup);
2690
2691 if (tiisg == 0) {
2692 shmem_f32[sgitg] = sumf;
2693 }
2694
2695 threadgroup_barrier(mem_flags::mem_threadgroup);
2696
2697 sumf = shmem_f32[tiisg];
2698 sumf = simd_sum(sumf);
2699
2700 const float mean = sumf/args.ne00;
2701 const float scale = 1.0f/sqrt(mean + args.eps);
2702
2703 device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
2704 for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) {
2705 if (F == 1) {
2706 y[i00] = (x[i00]*scale);
2707 }
2708 if (F == 2) {
2709 y[i00] = (x[i00]*scale)*f0[i00];
2710 }
2711 if (F == 3) {
2712 y[i00] = (x[i00]*scale)*f0[i00] + f1[i00];
2713 }
2714 }
2715}
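// note: unlike kernel_norm_fuse_impl, rms_norm skips the mean subtraction and
// needs only one reduction pass per row:
//
//   scale = 1 / sqrt(mean(x^2) + eps)
//   y     = x*scale, optionally fused with *f0 (F == 2) and +f1 (F == 3)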
2716
2717typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t;
2718
2719template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>;
2720template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>;
2721template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>;
2722
2723template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>;
2724template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>;
2725template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>;
2726
2727template <typename T0, typename T>
2728kernel void kernel_l2_norm_impl(
2729 constant ggml_metal_kargs_l2_norm & args,
2730 device const char * src0,
2731 device char * dst,
2732 threadgroup float * shmem_f32 [[threadgroup(0)]],
2733 uint3 tgpig[[threadgroup_position_in_grid]],
2734 ushort3 tpitg[[thread_position_in_threadgroup]],
2735 ushort sgitg[[simdgroup_index_in_threadgroup]],
2736 ushort tiisg[[thread_index_in_simdgroup]],
2737 ushort3 ntg[[threads_per_threadgroup]]) {
2738 const int i03 = tgpig.z;
2739 const int i02 = tgpig.y;
2740 const int i01 = tgpig.x;
2741
2742 if (sgitg == 0) {
2743 shmem_f32[tiisg] = 0.0f;
2744 }
2745
2746 device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
2747 device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1);
2748
2749 float sumf = 0.0f;
2750
2751 // parallel sum
2752 for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
2753 sumf += dot(x[i00], x[i00]);
2754 }
2755 sumf = simd_sum(sumf);
2756
2757 threadgroup_barrier(mem_flags::mem_threadgroup);
2758
2759 if (tiisg == 0) {
2760 shmem_f32[sgitg] = sumf;
2761 }
2762
2763 threadgroup_barrier(mem_flags::mem_threadgroup);
2764
2765 sumf = shmem_f32[tiisg];
2766 sumf = simd_sum(sumf);
2767
2768 const float scale = 1.0f/sqrt(max(sumf, args.eps));
2769
2770 for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
2771 y[i00] = x[i00] * scale;
2772 }
2773}
2774
2775typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t;
2776
2777template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>;
2778template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>;
2779
2780kernel void kernel_group_norm_f32(
2781 constant ggml_metal_kargs_group_norm & args,
2782 device const float * src0,
2783 device float * dst,
2784 threadgroup float * buf [[threadgroup(0)]],
2785 uint tgpig[[threadgroup_position_in_grid]],
2786 uint tpitg[[thread_position_in_threadgroup]],
2787 uint sgitg[[simdgroup_index_in_threadgroup]],
2788 uint tiisg[[thread_index_in_simdgroup]],
2789 uint ntg[[threads_per_threadgroup]]) {
2790 const int64_t ne = args.ne00*args.ne01*args.ne02;
2791 const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.ngrp - 1) / args.ngrp);
2792
2793 int start = tgpig * gs;
2794 int end = start + gs;
2795
2796 start += tpitg;
2797
2798 if (end >= ne) {
2799 end = ne;
2800 }
2801
2802 float tmp = 0.0f; // partial sum for thread in warp
2803
2804 for (int j = start; j < end; j += ntg) {
2805 tmp += src0[j];
2806 }
2807
2808 threadgroup_barrier(mem_flags::mem_threadgroup);
2809 tmp = simd_sum(tmp);
2810 if (ntg > N_SIMDWIDTH) {
2811 if (sgitg == 0) {
2812 buf[tiisg] = 0.0f;
2813 }
2814
2815 threadgroup_barrier(mem_flags::mem_threadgroup);
2816
2817 if (tiisg == 0) {
2818 buf[sgitg] = tmp;
2819 }
2820
2821 threadgroup_barrier(mem_flags::mem_threadgroup);
2822
2823 tmp = buf[tiisg];
2824 tmp = simd_sum(tmp);
2825 }
2826
2827 const float mean = tmp / gs;
2828 tmp = 0.0f;
2829
2830 for (int j = start; j < end; j += ntg) {
2831 float xi = src0[j] - mean;
2832 dst[j] = xi;
2833 tmp += xi * xi;
2834 }
2835
2836 tmp = simd_sum(tmp);
2837 if (ntg > N_SIMDWIDTH) {
2838 if (sgitg == 0) {
2839 buf[tiisg] = 0.0f;
2840 }
2841
2842 threadgroup_barrier(mem_flags::mem_threadgroup);
2843
2844 if (tiisg == 0) {
2845 buf[sgitg] = tmp;
2846 }
2847
2848 threadgroup_barrier(mem_flags::mem_threadgroup);
2849
2850 tmp = buf[tiisg];
2851 tmp = simd_sum(tmp);
2852 }
2853
2854 const float variance = tmp / gs;
2855 const float scale = 1.0f/sqrt(variance + args.eps);
2856 for (int j = start; j < end; j += ntg) {
2857 dst[j] *= scale;
2858 }
2859}
2860
2861// function to calculate the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i])
2862// il indicates where the q4 quants begin (0 or QK4_0/4)
2863// we assume that the yl's have been multiplied by the appropriate scale factors
2864// that correspond to the missing bit shifts (1, 1/16, 1/256, 1/4096)
2865inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
2866 float d = qb_curr->d;
2867
2868 float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
2869
2870 device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
2871
2872 for (int i = 0; i < 8; i += 2) {
2873 acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
2874 acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
2875 acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
2876 acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
2877 }
2878
2879 return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
2880}
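// note on the masking trick above: each uint16_t in qs packs four 4-bit
// quants; masking with 0x000F/0x0F00/0x00F0/0xF000 keeps one nibble without
// shifting it down, so the products are too large by 1x/256x/16x/4096x. the
// caller compensates by pre-dividing yl[i+1] by 256, yl[i+8] by 16 and
// yl[i+9] by 4096, and the q4_0 offset of -8 per quant is applied once at
// the end as sumy * -8.f.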
2881
2882// function to calculate the inner product between half a q4_1 block and 16 floats (yl); sumy is SUM(yl[i])
2883// il indicates where the q4 quants begin (0 or QK4_1/4)
2884// we assume that the yl's have been multiplied by the appropriate scale factors
2885// that correspond to the missing bit shifts (1, 1/16, 1/256, 1/4096)
2886inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
2887 float d = qb_curr->d;
2888 float m = qb_curr->m;
2889
2890 float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
2891
2892 device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);
2893
2894    for (int i = 0; i < 8; i += 2) {
2895 acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
2896 acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
2897 acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
2898 acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
2899 }
2900
2901 return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
2902}
2903
2904// function to calculate the inner product between half a q5_0 block and 16 floats (yl); sumy is SUM(yl[i])
2905// il indicates where the q5 quants begin (0 or QK5_0/4)
2906// we assume that the yl's have been multiplied by the appropriate scale factors
2907// that correspond to the missing bit shifts (1, 1/16, 1/256, 1/4096)
2908inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
2909 float d = qb_curr->d;
2910
2911 float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
2912
2913 device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
2914 const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
2915
2916    for (int i = 0; i < 8; i += 2) {
2917 acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
2918 acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
2919 acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
2920 acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
2921 }
2922
2923 return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);
2924}
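// note on the qh term above: ((qh >> (i+0+il)) << 4) & 0x00010 moves bit
// (i+il) of the packed high bits into bit 4, extending the low nibble to a
// 5-bit quant; the 0x01000/0x00100/0x10000 variants place the bit to match
// each mask-scaled nibble, and the offset correspondingly becomes sumy * -16.f.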
2925
2926// function to calculate the inner product between half a q5_1 block and 16 floats (yl); sumy is SUM(yl[i])
2927// il indicates where the q5 quants begin (0 or QK5_1/4)
2928// we assume that the yl's have been multiplied by the appropriate scale factors
2929// that correspond to the missing bit shifts (1, 1/16, 1/256, 1/4096)
2930inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
2931 float d = qb_curr->d;
2932 float m = qb_curr->m;
2933
2934 float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
2935
2936 device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
2937 const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
2938
2939    for (int i = 0; i < 8; i += 2) {
2940 acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
2941 acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
2942 acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
2943 acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
2944 }
2945
2946 return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
2947}
2948
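// two-stage reduction shared by the mat-vec kernels below: each simdgroup
// folds its NR0 partial sums with simd_sum and parks them in threadgroup
// memory (one NW-wide slot per row); the per-simdgroup partials are then
// reduced with a second simd_sum and thread (tiisg == 0, sgitg == 0) writes
// the final values.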
2949template<short NR0>
2950static inline void helper_mv_reduce_and_write(
2951 device float * dst_f32,
2952 float sumf[NR0],
2953 const int r0,
2954 const int ne01,
2955 ushort tiisg,
2956 ushort sgitg,
2957 threadgroup char * shmem) {
2958 constexpr short NW = N_SIMDWIDTH;
2959
2960 threadgroup float * shmem_f32[NR0];
2961
2962 for (short row = 0; row < NR0; ++row) {
2963 shmem_f32[row] = (threadgroup float *) shmem + NW*row;
2964
2965 if (sgitg == 0) {
2966 shmem_f32[row][tiisg] = 0.0f;
2967 }
2968
2969 sumf[row] = simd_sum(sumf[row]);
2970 }
2971
2972 threadgroup_barrier(mem_flags::mem_threadgroup);
2973
2974 for (short row = 0; row < NR0; ++row) {
2975 if (tiisg == 0) {
2976 shmem_f32[row][sgitg] = sumf[row];
2977 }
2978 }
2979
2980 threadgroup_barrier(mem_flags::mem_threadgroup);
2981
2982 for (short row = 0; row < NR0 && r0 + row < ne01; ++row) {
2983 float tot = simd_sum(shmem_f32[row][tiisg]);
2984
2985 if (tiisg == 0 && sgitg == 0) {
2986 dst_f32[r0 + row] = tot;
2987 }
2988 }
2989}
2990
2991constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
2992constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];
2993
2994template<typename block_q_type, short NR0, typename args_t>
2995void mul_vec_q_n_f32_impl(
2996 args_t args,
2997 device const char * src0,
2998 device const char * src1,
2999 device char * dst,
3000 threadgroup char * shmem,
3001 uint3 tgpig,
3002 ushort tiisg,
3003 ushort sgitg) {
3004 const short NSG = FC_mul_mv_nsg;
3005
3006 constexpr short NW = N_SIMDWIDTH;
3007 constexpr short NQ = 16;
3008
3009 const int nb = args.ne00/QK4_0;
3010
3011 const int r0 = (tgpig.x*NSG + sgitg)*NR0;
3012 //const int r0 = tgpig.x*NR0;
3013 const int r1 = tgpig.y;
3014 const int im = tgpig.z;
3015
3016 const uint i12 = im%args.ne12;
3017 const uint i13 = im/args.ne12;
3018
3019 //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3020 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3021
3022 //device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
3023 device const float * y = (device const float *) (src1 + offset1);
3024
3025 // pointers to src0 rows
3026 device const block_q_type * ax[NR0];
3027 FOR_UNROLL (int row = 0; row < NR0; ++row) {
3028 const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3029
3030 ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
3031 }
3032
3033 float sumf[NR0] = {0.f};
3034
3035 const short ix = (tiisg/(NW/NQ));
3036 const short il = (tiisg%(NW/NQ))*8;
3037
3038 //const int ib0 = sgitg*NQ + ix;
3039 const int ib0 = ix;
3040
3041 float yl[16]; // src1 vector cache
3042
3043 //device const float * yb = y + ix*QK4_0 + il;
3044 device const float * yb = y + ib0*QK4_0 + il;
3045
3046 // each thread in a SIMD group deals with half a block.
3047 //for (int ib = ib0; ib < nb; ib += NSG*NQ) {
3048 for (int ib = ib0; ib < nb; ib += NQ) {
3049 float sumy[2] = { 0.f, 0.f };
3050
3051 FOR_UNROLL (short i = 0; i < 8; i += 2) {
3052 sumy[0] += yb[i + 0] + yb[i + 1];
3053 yl[i + 0] = yb[i + 0];
3054 yl[i + 1] = yb[i + 1]/256.f;
3055
3056 sumy[1] += yb[i + 16] + yb[i + 17];
3057 yl[i + 8] = yb[i + 16]/16.f;
3058 yl[i + 9] = yb[i + 17]/4096.f;
3059 }
3060
3061 FOR_UNROLL (short row = 0; row < NR0; row++) {
3062 sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
3063 }
3064
3065 yb += QK4_0 * 16;
3066 //yb += NSG*NQ*QK4_0;
3067 }
3068
3069 device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
3070
3071 //helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
3072
3073 for (int row = 0; row < NR0; ++row) {
3074 const float tot = simd_sum(sumf[row]);
3075
3076 if (tiisg == 0 && r0 + row < args.ne01) {
3077 dst_f32[r0 + row] = tot;
3078 }
3079 }
3080}
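// note: worked thread mapping for the loop above, assuming NW == 32 and
// NQ == 16: NW/NQ == 2, so ix = tiisg/2 picks one of 16 blocks per iteration
// and il = (tiisg%2)*8 picks which half of that block, i.e. each pair of
// adjacent threads shares one q4/q5 block.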
3081
3082kernel void kernel_mul_mv_q4_0_f32(
3083 constant ggml_metal_kargs_mul_mv & args,
3084 device const char * src0,
3085 device const char * src1,
3086 device char * dst,
3087 threadgroup char * shmem [[threadgroup(0)]],
3088 uint3 tgpig[[threadgroup_position_in_grid]],
3089 ushort tiisg[[thread_index_in_simdgroup]],
3090 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3091 mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
3092}
3093
3094kernel void kernel_mul_mv_q4_1_f32(
3095 constant ggml_metal_kargs_mul_mv & args,
3096 device const char * src0,
3097 device const char * src1,
3098 device char * dst,
3099 threadgroup char * shmem [[threadgroup(0)]],
3100 uint3 tgpig[[threadgroup_position_in_grid]],
3101 ushort tiisg[[thread_index_in_simdgroup]],
3102 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3103 mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
3104}
3105
3106kernel void kernel_mul_mv_q5_0_f32(
3107 constant ggml_metal_kargs_mul_mv & args,
3108 device const char * src0,
3109 device const char * src1,
3110 device char * dst,
3111 threadgroup char * shmem [[threadgroup(0)]],
3112 uint3 tgpig[[threadgroup_position_in_grid]],
3113 ushort tiisg[[thread_index_in_simdgroup]],
3114 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3115 mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
3116}
3117
3118kernel void kernel_mul_mv_q5_1_f32(
3119 constant ggml_metal_kargs_mul_mv & args,
3120 device const char * src0,
3121 device const char * src1,
3122 device char * dst,
3123 threadgroup char * shmem [[threadgroup(0)]],
3124 uint3 tgpig[[threadgroup_position_in_grid]],
3125 ushort tiisg[[thread_index_in_simdgroup]],
3126 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3127 mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
3128}
3129
3130template<short NR0, typename args_t>
3131void kernel_mul_mv_q8_0_f32_impl(
3132 args_t args,
3133 device const char * src0,
3134 device const char * src1,
3135 device char * dst,
3136 threadgroup char * shmem,
3137 uint3 tgpig,
3138 ushort tiisg,
3139 ushort sgitg) {
3140 const short NSG = FC_mul_mv_nsg;
3141
3142 constexpr short NW = N_SIMDWIDTH;
3143 constexpr short NQ = 8;
3144
3145 const int nb = args.ne00/QK8_0;
3146
3147 const int r0 = tgpig.x*NR0;
3148 const int r1 = tgpig.y;
3149 const int im = tgpig.z;
3150
3151 const uint i12 = im%args.ne12;
3152 const uint i13 = im/args.ne12;
3153
3154 //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3155 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3156
3157 //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
3158 device const float * y = (device const float *) (src1 + offset1);
3159
3160 // pointers to src0 rows
3161 device const block_q8_0 * ax[NR0];
3162 FOR_UNROLL (short row = 0; row < NR0; ++row) {
3163 const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3164
3165 ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
3166 }
3167
3168 float sumf[NR0] = { 0.f };
3169
3170 const short ix = tiisg/(NW/NQ);
3171 const short il = tiisg%(NW/NQ);
3172
3173 const int ib0 = sgitg*NQ + ix;
3174
3175 float yl[NQ];
3176
3177 device const float * yb = y + ib0*QK8_0 + il*NQ;
3178
3179 // each thread in a SIMD group deals with NQ quants at a time
3180 for (int ib = ib0; ib < nb; ib += NSG*NQ) {
3181 for (short i = 0; i < NQ; ++i) {
3182 yl[i] = yb[i];
3183 }
3184
3185 for (short row = 0; row < NR0; row++) {
3186 device const int8_t * qs = ax[row][ib].qs + il*NQ;
3187
3188 float sumq = 0.f;
3189 FOR_UNROLL (short i = 0; i < NQ; ++i) {
3190 sumq += qs[i] * yl[i];
3191 }
3192
3193 sumf[row] += sumq*ax[row][ib].d;
3194 }
3195
3196 yb += NSG*NQ*QK8_0;
3197 }
3198
3199 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
3200
3201 helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
3202}
3203
3204[[host_name("kernel_mul_mv_q8_0_f32")]]
3205kernel void kernel_mul_mv_q8_0_f32(
3206 constant ggml_metal_kargs_mul_mv & args,
3207 device const char * src0,
3208 device const char * src1,
3209 device char * dst,
3210 threadgroup char * shmem [[threadgroup(0)]],
3211 uint3 tgpig[[threadgroup_position_in_grid]],
3212 ushort tiisg[[thread_index_in_simdgroup]],
3213 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3214 kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
3215}
3216
3217// mat-vec kernel processing in chunks of float4
3218// chpb - chunks per quantization block
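// e.g. with float4 chunks: a block_q8_0 holds 32 elements, so chpb == 32/4 == 8,
// while plain f32/f16 rows use 4-element "blocks" (chpb == 1); see the epb/4
// computation in the dispatchers below.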
3219template<short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
3220void kernel_mul_mv_ext_q4_f32_impl(
3221 constant ggml_metal_kargs_mul_mv_ext & args,
3222 device const char * src0,
3223 device const char * src1,
3224 device char * dst,
3225 uint3 tgpig[[threadgroup_position_in_grid]],
3226 ushort tiisg[[thread_index_in_simdgroup]],
3227 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3228 const short NSG = FC_mul_mv_nsg;
3229 const short nxpsg = FC_mul_mv_nxpsg;
3230
3231 const short chpt = 4; // chunks per thread
3232
3233 //const short nxpsg = (32);
3234 const short nypsg = (32/nxpsg);
3235
3236 const short tx = tiisg%nxpsg;
3237 const short ty = tiisg/nxpsg;
3238
3239 const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
3240 const int i11 = tgpig.y*r1ptg;
3241 const int i1m = tgpig.z;
3242
3243 const int i12 = i1m%args.ne12;
3244 const int i13 = i1m/args.ne12;
3245
3246 const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3247 const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3248
3249 device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
3250
3251 device const float4 * y4[r1ptg];
3252
3253 for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
3254 y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
3255 }
3256
3257 float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
3258
3259 short cch = tx%chpb; // current chunk index
3260
3261 for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
3262 float4 lx[chpt];
3263
3264#pragma unroll(chpt)
3265 for (short ch = 0; ch < chpt; ++ch) {
3266 deq_t4(xq, cch, lx[ch]);
3267
3268 cch += nxpsg;
3269 if (cch >= chpb) {
3270 xq += cch/chpb;
3271 cch %= chpb;
3272 }
3273 }
3274
3275#pragma unroll(chpt)
3276 for (short ch = 0; ch < chpt; ++ch) {
3277#pragma unroll(r1ptg)
3278 for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
3279 sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
3280 }
3281 }
3282
3283#pragma unroll(r1ptg)
3284 for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
3285 y4[ir1] += chpt*nxpsg;
3286 }
3287 }
3288
3289 // reduce only the threads in each row
3290 for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
3291 if (nxpsg >= 32) {
3292 sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
3293 }
3294 if (nxpsg >= 16) {
3295 sumf[ir1] += simd_shuffle_down(sumf[ir1], 8);
3296 }
3297 if (nxpsg >= 8) {
3298 sumf[ir1] += simd_shuffle_down(sumf[ir1], 4);
3299 }
3300 if (nxpsg >= 4) {
3301 sumf[ir1] += simd_shuffle_down(sumf[ir1], 2);
3302 }
3303 if (nxpsg >= 2) {
3304 sumf[ir1] += simd_shuffle_down(sumf[ir1], 1);
3305 }
3306
3307 //sumf[ir1] = simd_sum(sumf[ir1]);
3308 }
3309
3310 if (tx == 0) {
3311 for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
3312 device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
3313
3314 if (i01 < args.ne01) {
3315 dst_f32[i01] = sumf[ir1];
3316 }
3317 }
3318 }
3319}
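// note: the simd_shuffle_down ladder above is a log2-step reduction restricted
// to the nxpsg threads of one row: e.g. for nxpsg == 8 only the offsets 4, 2
// and 1 execute, after which lane tx == 0 of each row holds the row's full
// sum; a plain simd_sum is avoided because it would mix the nypsg rows that
// share the simdgroup.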
3320
3321// mat-vec kernel processing in chunks of float4x4
3322template<short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
3323void kernel_mul_mv_ext_q4x4_f32_impl(
3324 constant ggml_metal_kargs_mul_mv_ext & args,
3325 device const char * src0,
3326 device const char * src1,
3327 device char * dst,
3328 uint3 tgpig[[threadgroup_position_in_grid]],
3329 ushort tiisg[[thread_index_in_simdgroup]],
3330 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3331 const short NSG = FC_mul_mv_nsg;
3332 const short nxpsg = FC_mul_mv_nxpsg;
3333
3334 const short chpt = 1;
3335
3336 //const short nxpsg = (32);
3337 const short nypsg = (32/nxpsg);
3338
3339 const short tx = tiisg%nxpsg;
3340 const short ty = tiisg/nxpsg;
3341
3342 const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
3343 const int i11 = tgpig.y*r1ptg;
3344 const int i1m = tgpig.z;
3345
3346 const int i12 = i1m%args.ne12;
3347 const int i13 = i1m/args.ne12;
3348
3349 const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3350 const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3351
3352 device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
3353
3354 device const float4x4 * y4x4[r1ptg];
3355
3356 for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
3357 y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1;
3358 }
3359
3360 float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
3361
3362 short cch = tx%chpb;
3363
3364 for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) {
3365 float4x4 lx[chpt];
3366
3367#pragma unroll(chpt)
3368 for (short ch = 0; ch < chpt; ++ch) {
3369 deq_t4x4(xq, cch, lx[ch]);
3370
3371 cch += nxpsg;
3372 if (cch >= chpb) {
3373 xq += cch/chpb;
3374 cch %= chpb;
3375 }
3376 }
3377
3378#pragma unroll(chpt)
3379 for (short ch = 0; ch < chpt; ++ch) {
3380#pragma unroll(r1ptg)
3381 for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
3382 sumf[ir1] +=
3383 dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) +
3384 dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) +
3385 dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) +
3386 dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]);
3387
3388 }
3389 }
3390
3391#pragma unroll(r1ptg)
3392 for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
3393 y4x4[ir1] += chpt*nxpsg;
3394 }
3395 }
3396
3397 for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
3398 if (nxpsg >= 32) {
3399 sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
3400 }
3401 if (nxpsg >= 16) {
3402 sumf[ir1] += simd_shuffle_down(sumf[ir1], 8);
3403 }
3404 if (nxpsg >= 8) {
3405 sumf[ir1] += simd_shuffle_down(sumf[ir1], 4);
3406 }
3407 if (nxpsg >= 4) {
3408 sumf[ir1] += simd_shuffle_down(sumf[ir1], 2);
3409 }
3410 if (nxpsg >= 2) {
3411 sumf[ir1] += simd_shuffle_down(sumf[ir1], 1);
3412 }
3413
3414 //sumf[ir1] = simd_sum(sumf[ir1]);
3415 }
3416
3417 if (tx == 0) {
3418 for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
3419 device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
3420
3421 if (i01 < args.ne01) {
3422 dst_f32[i01] = sumf[ir1];
3423 }
3424 }
3425 }
3426}
3427
3428// dispatchers needed for compile-time nxpsg
3429// epb - elements per quantization block
3430template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
3431kernel void kernel_mul_mv_ext_q4_f32_disp(
3432 constant ggml_metal_kargs_mul_mv_ext & args,
3433 device const char * src0,
3434 device const char * src1,
3435 device char * dst,
3436 uint3 tgpig[[threadgroup_position_in_grid]],
3437 ushort tiisg[[thread_index_in_simdgroup]],
3438 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3439 kernel_mul_mv_ext_q4_f32_impl<r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
3440}
3441
3442template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
3443kernel void kernel_mul_mv_ext_q4x4_f32_disp(
3444 constant ggml_metal_kargs_mul_mv_ext & args,
3445 device const char * src0,
3446 device const char * src1,
3447 device char * dst,
3448 uint3 tgpig[[threadgroup_position_in_grid]],
3449 ushort tiisg[[thread_index_in_simdgroup]],
3450 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3451 kernel_mul_mv_ext_q4x4_f32_impl<r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
3452}
3453
3454typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
3455typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t;
3456
3457template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4, 4, dequantize_f32_t4>;
3458template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4, 4, dequantize_f32_t4>;
3459template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4, 4, dequantize_f32_t4>;
3460template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4, 4, dequantize_f32_t4>;
3461
3462template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>;
3463template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>;
3464template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
3465template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>;
3466
3467template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>;
3468template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>;
3469template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>;
3470template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0, 32, dequantize_q4_0_t4>;
3471
3472template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1, 32, dequantize_q4_1_t4>;
3473template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1, 32, dequantize_q4_1_t4>;
3474template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1, 32, dequantize_q4_1_t4>;
3475template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1, 32, dequantize_q4_1_t4>;
3476
3477template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0, 32, dequantize_q5_0_t4>;
3478template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0, 32, dequantize_q5_0_t4>;
3479template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0, 32, dequantize_q5_0_t4>;
3480template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0, 32, dequantize_q5_0_t4>;
3481
3482template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1, 32, dequantize_q5_1_t4>;
3483template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1, 32, dequantize_q5_1_t4>;
3484template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1, 32, dequantize_q5_1_t4>;
3485template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1, 32, dequantize_q5_1_t4>;
3486
3487template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>;
3488template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0, 32, dequantize_q8_0_t4>;
3489template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>;
3490template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>;
3491
3492template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_mxfp4, 32, dequantize_mxfp4_t4>;
3493template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_mxfp4, 32, dequantize_mxfp4_t4>;
3494template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4, 32, dequantize_mxfp4_t4>;
3495template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4, 32, dequantize_mxfp4_t4>;
3496
3497template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
3498template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
3499template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
3500template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
3501
3502template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>;
3503template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>;
3504template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>;
3505template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>;
3506
3507template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>;
3508template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>;
3509template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>;
3510template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>;
3511
3512template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>;
3513template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>;
3514template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
3515template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
3516
3517template<typename T0, typename T1, short NR0, typename args_t>
3518void kernel_mul_mv_t_t_impl(
3519 args_t args,
3520 device const char * src0,
3521 device const char * src1,
3522 device char * dst,
3523 threadgroup char * shmem,
3524 uint3 tgpig,
3525 ushort tiisg,
3526 ushort sgitg) {
3527 const short NSG = FC_mul_mv_nsg;
3528
3529 constexpr short NW = N_SIMDWIDTH;
3530 constexpr short NB = 32;
3531 constexpr short NF = 8;
3532
3533 const int nb = args.ne00/NB;
3534
3535 const int r0 = tgpig.x*NR0;
3536 const int r1 = tgpig.y;
3537 const int im = tgpig.z;
3538
3539 const uint i12 = im%args.ne12;
3540 const uint i13 = im/args.ne12;
3541
3542 //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3543 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3544
3545 //device const T0 * x = (device const T0 *) (src0 + offset0);
3546 device const T1 * y = (device const T1 *) (src1 + offset1);
3547
3548 // pointers to src0 rows
3549 device const T0 * ax [NR0];
3550 FOR_UNROLL (short row = 0; row < NR0; ++row) {
3551 const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3552
3553 ax[row] = (device const T0 *) ((device char *) src0 + offset0);
3554 }
3555
3556 float sumf[NR0] = { 0.f };
3557
3558 const short ix = tiisg/(NW/NF);
3559 const short il = tiisg%(NW/NF);
3560
3561 const int ib0 = sgitg*NF + ix;
3562
3563 T1 yl[NF];
3564
3565 device const T1 * yb = y + (ib0*NB + il*NF);
3566
3567 for (int ib = ib0; ib < nb; ib += NSG*NF) {
3568 for (short i = 0; i < NF; ++i) {
3569 yl[i] = yb[i];
3570 }
3571
3572 for (short row = 0; row < NR0; row++) {
3573 device const T0 * xb = ax[row] + (ib*NB + il*NF);
3574
3575 float sumq = 0.f;
3576 FOR_UNROLL (short i = 0; i < NF; ++i) {
3577 sumq += xb[i] * yl[i];
3578 }
3579
3580 sumf[row] += sumq;
3581 }
3582
3583 yb += NSG*NF*NW;
3584 }
3585
3586 for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {
3587 for (short row = 0; row < NR0; row++) {
3588 sumf[row] += ax[row][i] * y[i];
3589 }
3590 }
3591
3592 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
3593
3594 helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
3595}
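// note: the main loop above consumes ne00 in NB-sized blocks; the scalar loop
// starting at nb*NB handles the remainder when ne00 is not a multiple of NB,
// striding by NW*NSG so that every thread in the threadgroup participates.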
3596
3597template<typename T0, typename T1, typename args_t>
3598void kernel_mul_mv_t_t_disp(
3599 args_t args,
3600 device const char * src0,
3601 device const char * src1,
3602 device char * dst,
3603 threadgroup char * shmem,
3604 uint3 tgpig,
3605 ushort tiisg,
3606 ushort sgitg) {
3607 switch (args.nr0) {
3608 //case 1: kernel_mul_mv_t_t_impl<T0, T1, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
3609 case 2: kernel_mul_mv_t_t_impl<T0, T1, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
3610 //case 3: kernel_mul_mv_t_t_impl<T0, T1, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
3611 //case 4: kernel_mul_mv_t_t_impl<T0, T1, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
3612 }
3613}
3614
3615template<typename T0, typename T1>
3616kernel void kernel_mul_mv_t_t(
3617 constant ggml_metal_kargs_mul_mv & args,
3618 device const char * src0,
3619 device const char * src1,
3620 device char * dst,
3621 threadgroup char * shmem [[threadgroup(0)]],
3622 uint3 tgpig[[threadgroup_position_in_grid]],
3623 ushort tiisg[[thread_index_in_simdgroup]],
3624 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3625 kernel_mul_mv_t_t_disp<T0, T1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
3626}
3627
3628typedef decltype(kernel_mul_mv_t_t<half, half>) mul_mv_t_t;
3629
3630template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float>;
3631template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float>;
3632template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half>;
3633#if defined(GGML_METAL_HAS_BF16)
3634template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float>;
3635template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat>;
3636#endif
3637
3638template<typename T0, typename T04, typename T1, typename T14, short NR0, typename args_t>
3639void kernel_mul_mv_t_t_4_impl(
3640 args_t args,
3641 device const char * src0,
3642 device const char * src1,
3643 device char * dst,
3644 threadgroup char * shmem,
3645 uint3 tgpig,
3646 ushort tiisg,
3647 ushort sgitg) {
3648 const short NSG = FC_mul_mv_nsg;
3649
3650 constexpr short NW = N_SIMDWIDTH;
3651 constexpr short NB = 32;
3652 constexpr short NF = 16;
3653 constexpr short NF4 = NF/4;
3654
3655 const int nb = args.ne00/NB;
3656
3657 const int r0 = tgpig.x*NR0;
3658 const int r1 = tgpig.y;
3659 const int im = tgpig.z;
3660
3661 const uint i12 = im%args.ne12;
3662 const uint i13 = im/args.ne12;
3663
3664 //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3665 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3666
3667 device const T1 * y = (device const T1 *) (src1 + offset1);
3668 device const T14 * y4 = (device const T14 *) (src1 + offset1);
3669
3670 // pointers to src0 rows
3671 device const T0 * ax [NR0];
3672 device const T04 * ax4[NR0];
3673 FOR_UNROLL (short row = 0; row < NR0; ++row) {
3674 const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3675
3676 ax [row] = (device const T0 *) ((device char *) src0 + offset0);
3677 ax4[row] = (device const T04 *) ((device char *) src0 + offset0);
3678 }
3679
3680 float sumf[NR0] = { 0.f };
3681
3682 const short ix = tiisg/(NW/NF);
3683 const short il = tiisg%(NW/NF);
3684
3685 const int ib0 = sgitg*NF + ix;
3686
3687 T14 yl4[NF4];
3688
3689 device const T14 * yb4 = y4 + (ib0*NB + il*NF)/4;
3690
3691 for (int ib = ib0; ib < nb; ib += NSG*NF) {
3692 for (short i = 0; i < NF4; ++i) {
3693 yl4[i] = yb4[i];
3694 }
3695
3696 for (short row = 0; row < NR0; row++) {
3697 device const T04 * xb4 = ax4[row] + (ib*NB + il*NF)/4;
3698
3699 float sumq = 0.f;
3700 FOR_UNROLL (short i = 0; i < NF4; ++i) {
3701 sumq += dot(float4(xb4[i]), float4(yl4[i]));
3702 }
3703
3704 sumf[row] += sumq;
3705 }
3706
3707 yb4 += NSG*NF*NW/4;
3708 }
3709
3710 for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {
3711 for (short row = 0; row < NR0; row++) {
3712 sumf[row] += ax[row][i] * y[i];
3713 }
3714 }
3715
3716 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
3717
3718 helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
3719}
3720
3721template<typename T0, typename T04, typename T1, typename T14, typename args_t>
3722void kernel_mul_mv_t_t_4_disp(
3723 args_t args,
3724 device const char * src0,
3725 device const char * src1,
3726 device char * dst,
3727 threadgroup char * shmem,
3728 uint3 tgpig,
3729 ushort tiisg,
3730 ushort sgitg) {
3731 switch (args.nr0) {
3732 //case 1: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
3733 case 2: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
3734 //case 3: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
3735 //case 4: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break;
3736    }
3737}
3738
3739template<typename T0, typename T04, typename T1, typename T14>
3740kernel void kernel_mul_mv_t_t_4(
3741 constant ggml_metal_kargs_mul_mv & args,
3742 device const char * src0,
3743 device const char * src1,
3744 device char * dst,
3745 threadgroup char * shmem [[threadgroup(0)]],
3746 uint3 tgpig[[threadgroup_position_in_grid]],
3747 ushort tiisg[[thread_index_in_simdgroup]],
3748 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
3749 kernel_mul_mv_t_t_4_disp<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
3750}
3751
3752typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4>) mul_mv_t_t_4;
3753
3754template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4>;
3755template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4>;
3756template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4>;
3757#if defined(GGML_METAL_HAS_BF16)
3758template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4>;
3759template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4>;
3760#endif
3761
3762template<typename T0, typename T1, typename args_t>
3763void kernel_mul_mv_t_t_short_impl(
3764 args_t args,
3765 device const char * src0,
3766 device const char * src1,
3767 device char * dst,
3768 uint3 tgpig,
3769 ushort tiisg) {
3770 const int r0 = tgpig.x*32 + tiisg;
3771 const int r1 = tgpig.y;
3772 const int im = tgpig.z;
3773
3774 if (r0 >= args.ne01) {
3775 return;
3776 }
3777
3778 const uint i12 = im%args.ne12;
3779 const uint i13 = im/args.ne12;
3780
3781 const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
3782
3783 device const T0 * x = (device const T0 *) (src0 + offset0);
3784
3785 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
3786
3787 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
3788
3789 device const T1 * y = (device const T1 *) (src1 + offset1);
3790
3791 float res = 0.0f;
3792
3793 for (int i = 0; i < args.ne00; ++i) {
3794 res += (float) x[i] * (float) y[i];
3795 }
3796
3797 dst_f32[(uint64_t)r1*args.ne0 + r0] = res;
3798}
3799
3800template<typename T0, typename T1>
3801kernel void kernel_mul_mv_t_t_short(
3802 constant ggml_metal_kargs_mul_mv & args,
3803 device const char * src0,
3804 device const char * src1,
3805 device char * dst,
3806 uint3 tgpig[[threadgroup_position_in_grid]],
3807 ushort tiisg[[thread_index_in_simdgroup]]) {
3808 kernel_mul_mv_t_t_short_impl<T0, T1, constant ggml_metal_kargs_mul_mv &>(
3809 args,
3810 src0,
3811 src1,
3812 dst,
3813 tgpig,
3814 tiisg);
3815}
3816
3817typedef decltype(kernel_mul_mv_t_t_short<half, half>) mul_mv_t_t_short_t;
3818
3819template [[host_name("kernel_mul_mv_f32_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<float, float>;
3820template [[host_name("kernel_mul_mv_f16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half, float>;
3821template [[host_name("kernel_mul_mv_f16_f16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half, half>;
3822#if defined(GGML_METAL_HAS_BF16)
3823template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, float>;
3824template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>;
3825#endif
3826
3827constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]];
3828
3829static float rope_yarn_ramp(const float low, const float high, const int i0) {
3830 const float y = (i0 / 2 - low) / max(0.001f, high - low);
3831 return 1.0f - min(1.0f, max(0.0f, y));
3832}
3833
3834// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
3835// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
3836static void rope_yarn(
3837 float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
3838 thread float * cos_theta, thread float * sin_theta) {
3839 // Get n-d rotational scaling corrected for extrapolation
3840 float theta_interp = freq_scale * theta_extrap;
3841 float theta = theta_interp;
3842 if (ext_factor != 0.0f) {
3843 float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
3844 theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
3845
3846 // Get n-d magnitude scaling corrected for interpolation
3847 mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
3848 }
3849 *cos_theta = cos(theta) * mscale;
3850 *sin_theta = sin(theta) * mscale;
3851}
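// i.e. with mix = rope_yarn_ramp(...)*ext_factor:
//
//   theta = theta_interp*(1 - mix) + theta_extrap*mix
//
// dimensions below corr_dims[0] extrapolate (ramp == 1), dimensions above
// corr_dims[1] interpolate (ramp == 0), and the band in between blends the
// two; mscale corrects the magnitude by 1 + 0.1*ln(1/freq_scale) when
// ext_factor != 0.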
3852
3853// Solving `max_pos_emb = n_rot * 2pi * base^((2 * x) / n_dims)` for x, we get
3854// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
3855static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
3856 return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
3857}
3858
3859static void rope_yarn_corr_dims(
3860 int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]
3861) {
3862 // start and end correction dims
3863 dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base)));
3864 dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
3865}
3866
3867template<typename T>
3868kernel void kernel_rope_norm(
3869 constant ggml_metal_kargs_rope & args,
3870 device const char * src0,
3871 device const char * src1,
3872 device const char * src2,
3873 device char * dst,
3874 ushort tiitg[[thread_index_in_threadgroup]],
3875 ushort3 tptg [[threads_per_threadgroup]],
3876 uint3 tgpig[[threadgroup_position_in_grid]]) {
3877 const int i3 = tgpig[2];
3878 const int i2 = tgpig[1];
3879 const int i1 = tgpig[0];
3880
3881 float corr_dims[2];
3882 rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
3883
3884 device const int32_t * pos = (device const int32_t *) src1;
3885
3886 const float theta_base = (float) pos[i2];
3887 const float inv_ndims = -1.f/args.n_dims;
3888
3889 float cos_theta;
3890 float sin_theta;
3891
3892 for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
3893 if (i0 < args.n_dims) {
3894 const int ic = i0/2;
3895
3896 const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
3897
3898 const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
3899
3900 rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
3901
3902 device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
3903 device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
3904
3905 const float x0 = src[0];
3906 const float x1 = src[1];
3907
3908 dst_data[0] = x0*cos_theta - x1*sin_theta;
3909 dst_data[1] = x0*sin_theta + x1*cos_theta;
3910 } else {
3911 device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
3912 device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
3913
3914 dst_data[0] = src[0];
3915 dst_data[1] = src[1];
3916 }
3917 }
3918}
3919
3920template<typename T>
3921kernel void kernel_rope_neox(
3922 constant ggml_metal_kargs_rope & args,
3923 device const char * src0,
3924 device const char * src1,
3925 device const char * src2,
3926 device char * dst,
3927 ushort tiitg[[thread_index_in_threadgroup]],
3928 ushort3 tptg [[threads_per_threadgroup]],
3929 uint3 tgpig[[threadgroup_position_in_grid]]) {
3930 const int i3 = tgpig[2];
3931 const int i2 = tgpig[1];
3932 const int i1 = tgpig[0];
3933
3934 float corr_dims[2];
3935 rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
3936
3937 device const int32_t * pos = (device const int32_t *) src1;
3938
3939 const float theta_base = (float) pos[i2];
3940 const float inv_ndims = -1.f/args.n_dims;
3941
3942 float cos_theta;
3943 float sin_theta;
3944
3945 for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
3946 if (i0 < args.n_dims) {
3947 const int ic = i0/2;
3948
3949 const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
3950
3951 const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
3952
3953 rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
3954
3955 device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
3956 device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
3957
3958 const float x0 = src[0];
3959 const float x1 = src[args.n_dims/2];
3960
3961 dst_data[0] = x0*cos_theta - x1*sin_theta;
3962 dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
3963 } else {
3964 device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
3965 device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
3966
3967 dst_data[0] = src[0];
3968 dst_data[1] = src[1];
3969 }
3970 }
3971}
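
// note: same rotation as kernel_rope_norm, but in the NeoX layout element ic is paired with
// element ic + n_dims/2 (first half with second half) instead of its adjacent neighbor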
3972
3973template<typename T>
3974kernel void kernel_rope_multi(
3975 constant ggml_metal_kargs_rope & args,
3976 device const char * src0,
3977 device const char * src1,
3978 device const char * src2,
3979 device char * dst,
3980 ushort tiitg[[thread_index_in_threadgroup]],
3981 ushort3 tptg [[threads_per_threadgroup]],
3982 uint3 tgpig[[threadgroup_position_in_grid]]) {
3983 const int i3 = tgpig[2];
3984 const int i2 = tgpig[1];
3985 const int i1 = tgpig[0];
3986
3987 float corr_dims[2];
3988 rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
3989
3990 device const int32_t * pos = (device const int32_t *) src1;
3991
3992 const float inv_ndims = -1.f/args.n_dims;
3993
3994 float cos_theta;
3995 float sin_theta;
3996
3997 for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
3998 if (i0 < args.n_dims) {
3999 const int ic = i0/2;
4000
4001 // mrope theta calculations
4002 // note: the rest is the same as kernel_rope_neox
4003 const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
4004 const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1
4005 const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
4006 const int sector = ic % sect_dims;
4007
4008 float theta_base;
4009 if (FC_rope_is_imrope) {
4010 if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h
4011 theta_base = (float) pos[i2 + args.ne02 * 1];
4012 } else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w
4013 theta_base = (float) pos[i2 + args.ne02 * 2];
4014 } else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t
4015 theta_base = (float) pos[i2 + args.ne02 * 0];
4016 } else { // e
4017 theta_base = (float) pos[i2 + args.ne02 * 3];
4018 }
4019 } else {
4020 if (sector < args.sect_0) {
4021 theta_base = (float) pos[i2];
4022 } else if (sector < sec_w01) {
4023 theta_base = (float) pos[i2 + args.ne02 * 1];
4024 } else if (sector < sec_w012) {
4025 theta_base = (float) pos[i2 + args.ne02 * 2];
4026 } else {
4027 theta_base = (float) pos[i2 + args.ne02 * 3];
4028 }
4029 }
4030 // end of mrope
4031
4032 const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
4033
4034 const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
4035
4036 rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
4037
4038 device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
4039 device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
4040
4041 const float x0 = src[0];
4042 const float x1 = src[args.n_dims/2];
4043
4044 dst_data[0] = x0*cos_theta - x1*sin_theta;
4045 dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
4046 } else {
4047 device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
4048 device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
4049
4050 dst_data[0] = src[0];
4051 dst_data[1] = src[1];
4052 }
4053 }
4054}
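
// note: a hypothetical example of the section mapping above, assuming sect = {16, 24, 24, 0}:
// sectors [0, 16) take the position from the t channel (pos[i2]), [16, 40) from the h channel
// (pos[i2 + ne02]) and [40, 64) from the w channel (pos[i2 + 2*ne02]), repeating every sect_dims pairs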
4055
4056template<typename T>
4057kernel void kernel_rope_vision(
4058 constant ggml_metal_kargs_rope & args,
4059 device const char * src0,
4060 device const char * src1,
4061 device const char * src2,
4062 device char * dst,
4063 ushort tiitg[[thread_index_in_threadgroup]],
4064 ushort3 tptg [[threads_per_threadgroup]],
4065 uint3 tgpig[[threadgroup_position_in_grid]]) {
4066 const int i3 = tgpig[2];
4067 const int i2 = tgpig[1];
4068 const int i1 = tgpig[0];
4069
4070 float corr_dims[2];
4071 rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
4072
4073 device const int32_t * pos = (device const int32_t *) src1;
4074
4075 const float inv_ndims = -1.f/args.n_dims;
4076
4077 float cos_theta;
4078 float sin_theta;
4079
4080 for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
4081 if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
4082 const int ic = i0/2;
4083
4084 // mrope theta calculations (only support 2 dimensions)
4085 const int sect_dims = args.sect_0 + args.sect_1;
4086 const int sector = ic % sect_dims;
4087
4088 float p;
4089 float theta_base;
4090 if (sector < args.sect_1) {
4091 p = (float) sector;
4092 theta_base = (float) pos[i2];
4093 } else {
4094 p = (float) sector - args.sect_0;
4095 theta_base = (float) pos[i2 + args.ne02];
4096 }
4097
4098 const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
4099 // end of mrope
4100
4101 const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
4102
4103 rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
4104
4105 device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
4106 device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
4107
4108 const float x0 = src[0];
4109 const float x1 = src[args.n_dims]; // different from kernel_rope_multi
4110
4111 dst_data[0] = x0*cos_theta - x1*sin_theta;
4112 dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
4113 } else {
4114 device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
4115 device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
4116
4117 dst_data[0] = src[0];
4118 dst_data[1] = src[1];
4119 }
4120 }
4121}
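
// note: in the vision variant each position channel typically covers half of the pairs (sect_0 == sect_1),
// the angle theta = pos * freq_base^(-2*p/n_dims) restarts the frequency ladder per section, and the
// rotation pairs element ic with element ic + n_dims, covering all 2*n_dims coordinates of the head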
4122
4123typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
4124typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
4125typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
4126typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
4127
4128template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
4129template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
4130
4131template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
4132template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
4133
4134template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
4135template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
4136
4137template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
4138template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
4139
4140typedef void (im2col_t)(
4141 constant ggml_metal_kargs_im2col & args,
4142 device const float * x,
4143 device char * dst,
4144 uint3 tgpig[[threadgroup_position_in_grid]],
4145 uint3 tgpg[[threadgroups_per_grid]],
4146 uint3 tpitg[[thread_position_in_threadgroup]],
4147 uint3 ntg[[threads_per_threadgroup]]);
4148
4149template <typename T>
4150kernel void kernel_im2col(
4151 constant ggml_metal_kargs_im2col & args,
4152 device const float * x,
4153 device char * dst,
4154 uint3 tgpig[[threadgroup_position_in_grid]],
4155 uint3 tgpg[[threadgroups_per_grid]],
4156 uint3 tpitg[[thread_position_in_threadgroup]],
4157 uint3 ntg[[threads_per_threadgroup]]) {
4158// const int64_t IC = tgpg[0];
4159 const int64_t OH = tgpg[1];
4160 const int64_t OW = tgpg[2];
4161
4162 const int64_t KH = ntg[1];
4163 const int64_t KW = ntg[2];
4164
4165 int64_t in = tpitg[0];
4166 const int64_t ikh = tpitg[1];
4167 const int64_t ikw = tpitg[2];
4168
4169 const int64_t iic = tgpig[0];
4170 const int64_t ioh = tgpig[1];
4171 const int64_t iow = tgpig[2];
4172
4173 const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
4174 const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
4175
4176 int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
4177
4178 device T * pdst = (device T *) (dst);
4179
4180 if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
4181 while (in < args.N) {
4182 pdst[offset_dst] = 0.0f;
4183 offset_dst += ntg[0]*args.CHW*OH*OW;
4184
4185 in += ntg[0];
4186 }
4187 } else {
4188 int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
4189
4190 while (in < args.N) {
4191 pdst[offset_dst] = x[offset_src];
4192
4193 offset_dst += ntg[0]*args.CHW*OH*OW;
4194 offset_src += ntg[0]*args.ofs0;
4195
4196 in += ntg[0];
4197 }
4198 }
4199}
4200
4201template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
4202template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
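
// note: the result is the usual im2col matrix of shape [N*OH*OW, IC*KH*KW]: each row holds the flattened
// receptive field of one output pixel, so the convolution itself reduces to a matrix multiplication with
// the [OC, IC*KH*KW] weight matrix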
4203
4204// TODO: obsolete -- remove
4205//typedef void (im2col_ext_t)(
4206// constant ggml_metal_kargs_im2col & args,
4207// device const float * x,
4208// device char * dst,
4209// uint3 tgpig[[threadgroup_position_in_grid]],
4210// uint3 tgpg[[threadgroups_per_grid]],
4211// uint3 tpitg[[thread_position_in_threadgroup]],
4212// uint3 ntg[[threads_per_threadgroup]]);
4213//
4214//template <typename T>
4215//kernel void kernel_im2col_ext(
4216// constant ggml_metal_kargs_im2col & args,
4217// device const float * x,
4218// device char * dst,
4219// uint3 tgpig[[threadgroup_position_in_grid]],
4220// uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
4221// uint3 tpitg[[thread_position_in_threadgroup]],
4222// uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
4223// const int64_t KHW = (int64_t)args.KHW;
4224//
4225// const int64_t d = tgpig[0] / args.CHW;
4226// const int64_t chw = tgpig[0] % args.CHW;
4227// const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
4228// const int64_t HW = tgpig[0] % KHW;
4229//
4230// const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
4231// if (tpitg_0 >= args.N) {
4232// return;
4233// }
4234//
4235// const int64_t tpitg_1 = HW / args.KW;
4236// const int64_t tpitg_2 = HW % args.KW;
4237//
4238// const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
4239// const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
4240//
4241// const int64_t offset_dst =
4242// (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
4243// (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
4244//
4245// device T * pdst = (device T *) (dst);
4246//
4247// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
4248// pdst[offset_dst] = 0.0f;
4249// } else {
4250// const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
4251// pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
4252// }
4253//}
4254//
4255//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
4256//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
4257
4258template <typename TK>
4259kernel void kernel_conv_2d(
4260 constant ggml_metal_kargs_conv_2d & args,
4261 device const char * weights,
4262 device const char * src,
4263 device char * dst,
4264 uint3 tgpig[[threadgroup_position_in_grid]],
4265 uint3 tgpg[[threadgroups_per_grid]],
4266 uint3 tpitg[[thread_position_in_threadgroup]],
4267 uint3 ntg[[threads_per_threadgroup]]) {
4268
4269 const uint threads_per_tg = ntg.x * ntg.y * ntg.z;
4270 const uint tg_index = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
4271 const uint local_thread = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;
4272 const uint thread_index = tg_index * threads_per_tg + local_thread;
4273 const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
4274 const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;
4275
4276 for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
4277 uint64_t tmp = index;
4278
4279 const int32_t ow = tmp % args.OW; tmp /= args.OW;
4280 const int32_t oh = tmp % args.OH; tmp /= args.OH;
4281 const int32_t oc = tmp % args.OC; tmp /= args.OC;
4282 const int32_t n = tmp;
4283
4284 float acc = 0.0f;
4285
4286 const int32_t base_x = ow*args.s0 - args.p0;
4287 const int32_t base_y = oh*args.s1 - args.p1;
4288
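        // clip the kernel window to the input bounds: we need 0 <= base_y + ky*d1 <= IH - 1, i.e.
        // ky >= ceil(-base_y/d1) and ky <= floor((IH - 1 - base_y)/d1), and similarly for kx,
        // so that the inner loops below can run without per-tap bound checks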
4289 int32_t ky_start = 0;
4290 if (base_y < 0) {
4291 ky_start = (-base_y + args.d1 - 1)/args.d1;
4292 }
4293 int32_t ky_end = args.KH;
4294 const int32_t y_max = args.IH - 1 - base_y;
4295 if (y_max < 0) {
4296 ky_end = ky_start;
4297 } else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
4298 ky_end = min(ky_end, y_max/args.d1 + 1);
4299 }
4300
4301 int32_t kx_start = 0;
4302 if (base_x < 0) {
4303 kx_start = (-base_x + args.d0 - 1)/args.d0;
4304 }
4305 int32_t kx_end = args.KW;
4306 const int32_t x_max = args.IW - 1 - base_x;
4307 if (x_max < 0) {
4308 kx_end = kx_start;
4309 } else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
4310 kx_end = min(kx_end, x_max/args.d0 + 1);
4311 }
4312
4313 if (ky_start < ky_end && kx_start < kx_end) {
4314 const uint64_t src_base_n = (uint64_t) n * args.nb13;
4315 const uint64_t w_base_oc = (uint64_t) oc * args.nb03;
4316
4317 for (int32_t ic = 0; ic < args.IC; ++ic) {
4318 const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;
4319 const uint64_t w_base_ocic = w_base_oc + (uint64_t) ic * args.nb02;
4320
4321 for (int32_t ky = ky_start; ky < ky_end; ++ky) {
4322 const int32_t iy = base_y + ky*args.d1;
4323 const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;
4324 const uint64_t w_base_row = w_base_ocic + (uint64_t) ky * args.nb01;
4325
4326 for (int32_t kx = kx_start; kx < kx_end; ++kx) {
4327 const int32_t ix = base_x + kx*args.d0;
4328 const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;
4329 const uint64_t w_offs = w_base_row + (uint64_t) kx * args.nb00;
4330
4331 const float x = *(device const float *)(src + src_offs);
4332 const float w = (float) (*(device const TK *)(weights + w_offs));
4333
4334 acc += x * w;
4335 }
4336 }
4337 }
4338 }
4339
4340 const uint64_t dst_offs =
4341 (uint64_t) n * args.nb3 +
4342 (uint64_t) oc * args.nb2 +
4343 (uint64_t) oh * args.nb1 +
4344 (uint64_t) ow * args.nb0;
4345
4346 *(device float *)(dst + dst_offs) = acc;
4347 }
4348}
4349
4350template [[host_name("kernel_conv_2d_f32_f32")]]
4351kernel void kernel_conv_2d<float>(
4352 constant ggml_metal_kargs_conv_2d & args,
4353 device const char * weights,
4354 device const char * src,
4355 device char * dst,
4356 uint3 tgpig[[threadgroup_position_in_grid]],
4357 uint3 tgpg[[threadgroups_per_grid]],
4358 uint3 tpitg[[thread_position_in_threadgroup]],
4359 uint3 ntg[[threads_per_threadgroup]]);
4360
4361template [[host_name("kernel_conv_2d_f16_f32")]]
4362kernel void kernel_conv_2d<half>(
4363 constant ggml_metal_kargs_conv_2d & args,
4364 device const char * weights,
4365 device const char * src,
4366 device char * dst,
4367 uint3 tgpig[[threadgroup_position_in_grid]],
4368 uint3 tgpg[[threadgroups_per_grid]],
4369 uint3 tpitg[[thread_position_in_threadgroup]],
4370 uint3 ntg[[threads_per_threadgroup]]);
4371
4372typedef void (conv_transpose_1d_t)(
4373 constant ggml_metal_kargs_conv_transpose_1d & args,
4374 device const float * src0,
4375 device const float * src1,
4376 device char * dst,
4377 uint3 tgpig[[threadgroup_position_in_grid]],
4378 uint3 tgpg[[threadgroups_per_grid]]);
4379
4380template <typename T>
4381kernel void kernel_conv_transpose_1d(
4382 constant ggml_metal_kargs_conv_transpose_1d & args,
4383 device const T * src0,
4384 device const float * src1,
4385 device char * dst,
4386 uint3 tgpig[[threadgroup_position_in_grid]],
4387 uint3 tgpg[[threadgroups_per_grid]]) {
4388
4389 float v = 0.0f;
4390
4391 for (int64_t c = 0; c < args.IC; c++) {
4392 const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1];
4393 const int32_t input_offset = c * args.IL;
4394
4395 for (int64_t i = 0; i < args.IL; i++) {
4396 if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) {
4397 v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
4398 }
4399 }
4400 }
4401
4402 device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
4403
4404 dst_ptr[0] = v;
4405}
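
// note: this computes the transposed convolution as a gather: output position tgpig[0] accumulates
// src0[c, tgpig[0] - i*s0] * src1[c, i] over all input positions i with 0 <= tgpig[0] - i*s0 < K,
// i.e. the adjoint of a forward convolution with stride s0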
4406
4407template [[host_name("kernel_conv_transpose_1d_f32_f32")]]
4408kernel void kernel_conv_transpose_1d<float>(
4409 constant ggml_metal_kargs_conv_transpose_1d & args,
4410 device const float * src0,
4411 device const float * src1,
4412 device char * dst,
4413 uint3 tgpig[[threadgroup_position_in_grid]],
4414 uint3 tgpg[[threadgroups_per_grid]]);
4415
4416template [[host_name("kernel_conv_transpose_1d_f16_f32")]]
4417kernel void kernel_conv_transpose_1d<half>(
4418 constant ggml_metal_kargs_conv_transpose_1d & args,
4419 device const half * src0,
4420 device const float * src1,
4421 device char * dst,
4422 uint3 tgpig[[threadgroup_position_in_grid]],
4423 uint3 tgpg[[threadgroups_per_grid]]);
4424
4425
4426typedef void (conv_transpose_2d_t)(
4427 constant ggml_metal_kargs_conv_transpose_2d & args,
4428 device const float * src0,
4429 device const float * src1,
4430 device char * dst,
4431 uint3 tgpig[[threadgroup_position_in_grid]],
4432 uint3 tgpg[[threadgroups_per_grid]]);
4433
4434template <typename T>
4435kernel void kernel_conv_transpose_2d(
4436 constant ggml_metal_kargs_conv_transpose_2d & args,
4437 device const T * src0,
4438 device const float * src1,
4439 device char * dst,
4440 threadgroup float * shared_sum [[threadgroup(0)]],
4441 uint3 tgpig[[threadgroup_position_in_grid]],
4442 uint3 tpitg[[thread_position_in_threadgroup]],
4443 uint3 ntg[[threads_per_threadgroup]]) {
4444
4445 const int64_t out_x = tgpig[0];
4446 const int64_t out_y = tgpig[1];
4447 const int64_t out_c = tgpig[2];
4448
4449 const int64_t kw = tpitg[0];
4450 const int64_t kh = tpitg[1];
4451
4452 float v = 0.0f;
4453
4454 for (int64_t in_c = 0; in_c < args.IC; in_c++) {
4455 int64_t in_y = out_y - kh;
4456
4457 if (in_y < 0 || in_y % args.s0) continue;
4458
4459 in_y /= args.s0;
4460
4461 if (in_y >= args.IH) continue;
4462
4463 int64_t in_x = out_x - kw;
4464
4465 if (in_x < 0 || in_x % args.s0) continue;
4466
4467 in_x /= args.s0;
4468
4469 if (in_x >= args.IW) continue;
4470
4471 const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
4472 const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
4473
4474 v += (float)src0[kernel_idx] * src1[input_idx];
4475 }
4476
4477 const uint tid = tpitg.y * ntg.x + tpitg.x;
4478 shared_sum[tid] = v;
4479
4480 threadgroup_barrier(mem_flags::mem_threadgroup);
4481
4482 if (tid == 0) {
4483 float total = 0.0f;
4484 const uint num_threads = ntg.x * ntg.y;
4485 for (uint i = 0; i < num_threads; i++) {
4486 total += shared_sum[i];
4487 }
4488
4489 device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
4490 dst_ptr[0] = total;
4491 }
4492}
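
// note: one threadgroup per output element (out_x, out_y, out_c): each thread owns a single (kw, kh)
// kernel tap, accumulates that tap over all input channels, and thread 0 reduces the KW*KH partial
// sums from threadgroup memory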
4493
4494template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
4495kernel void kernel_conv_transpose_2d<float>(
4496 constant ggml_metal_kargs_conv_transpose_2d & args,
4497 device const float * src0,
4498 device const float * src1,
4499 device char * dst,
4500 threadgroup float * shared_sum [[threadgroup(0)]],
4501 uint3 tgpig[[threadgroup_position_in_grid]],
4502 uint3 tpitg[[thread_position_in_threadgroup]],
4503 uint3 ntg[[threads_per_threadgroup]]);
4504
4505template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
4506kernel void kernel_conv_transpose_2d<half>(
4507 constant ggml_metal_kargs_conv_transpose_2d & args,
4508 device const half * src0,
4509 device const float * src1,
4510 device char * dst,
4511 threadgroup float * shared_sum [[threadgroup(0)]],
4512 uint3 tgpig[[threadgroup_position_in_grid]],
4513 uint3 tpitg[[thread_position_in_threadgroup]],
4514 uint3 ntg[[threads_per_threadgroup]]);
4515
4516kernel void kernel_upscale_f32(
4517 constant ggml_metal_kargs_upscale & args,
4518 device const char * src0,
4519 device char * dst,
4520 uint3 tgpig[[threadgroup_position_in_grid]],
4521 uint3 tpitg[[thread_position_in_threadgroup]],
4522 uint3 ntg[[threads_per_threadgroup]]) {
4523
4524 const int64_t i3 = tgpig.z;
4525 const int64_t i2 = tgpig.y;
4526 const int64_t i1 = tgpig.x;
4527
4528 const int64_t i03 = i3/args.sf3;
4529 const int64_t i02 = i2/args.sf2;
4530 const int64_t i01 = i1/args.sf1;
4531
4532 for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
4533 const int64_t i00 = i0/args.sf0;
4534
4535 device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
4536 device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
4537
4538 dst_ptr[0] = src0_ptr[0];
4539 }
4540}
4541
4542kernel void kernel_pad_f32(
4543 constant ggml_metal_kargs_pad & args,
4544 device const char * src0,
4545 device char * dst,
4546 uint3 tgpig[[threadgroup_position_in_grid]],
4547 uint3 tpitg[[thread_position_in_threadgroup]],
4548 uint3 ntg[[threads_per_threadgroup]]) {
4549
4550 const int64_t i3 = tgpig.z;
4551 const int64_t i2 = tgpig.y;
4552 const int64_t i1 = tgpig.x;
4553
4554 const int64_t i03 = i3;
4555 const int64_t i02 = i2;
4556 const int64_t i01 = i1;
4557
4558 device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
4559 device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
4560
4561 if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
4562 for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
4563 if (i0 < args.ne00) {
4564 dst_ptr[i0] = src0_ptr[i0];
4565 } else {
4566 dst_ptr[i0] = 0.0f;
4567 }
4568 }
4569
4570 return;
4571 }
4572
4573 for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
4574 dst_ptr[i0] = 0.0f;
4575 }
4576}
4577
4578kernel void kernel_pad_reflect_1d_f32(
4579 constant ggml_metal_kargs_pad_reflect_1d & args,
4580 device const char * src0,
4581 device char * dst,
4582 uint3 tgpig[[threadgroup_position_in_grid]],
4583 uint3 tgpg[[threadgroups_per_grid]],
4584 uint3 tpitg[[thread_position_in_threadgroup]],
4585 uint3 ntg[[threads_per_threadgroup]]) {
4586
4587 const int64_t i3 = tgpig.z;
4588 const int64_t i2 = tgpig.y;
4589 const int64_t i1 = tgpig.x;
4590
4591 const int64_t i03 = i3;
4592 const int64_t i02 = i2;
4593 const int64_t i01 = i1;
4594
4595 device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
4596 device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1);
4597
4598 if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
4599 for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
4600 if (i0 < args.p0) {
4601 dst_ptr[i0] = src0_ptr[args.p0 - i0];
4602 } else if (i0 < args.ne0 - args.p1) {
4603 dst_ptr[i0] = src0_ptr[i0 - args.p0];
4604 } else {
4605 dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
4606 }
4607 }
4608 }
4609}
4610
4611kernel void kernel_arange_f32(
4612 constant ggml_metal_kargs_arange & args,
4613 device char * dst,
4614 uint3 tgpig[[threadgroup_position_in_grid]],
4615 uint3 tpitg[[thread_position_in_threadgroup]],
4616 uint3 ntg[[threads_per_threadgroup]]) {
4617
4618 device float * dst_ptr = (device float *) dst;
4619
4620 for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
4621 dst_ptr[i0] = args.start + args.step * i0;
4622 }
4623}
4624
4625kernel void kernel_timestep_embedding_f32(
4626 constant ggml_metal_kargs_timestep_embedding & args,
4627 device const char * src0,
4628 device char * dst,
4629 uint3 tgpig[[threadgroup_position_in_grid]],
4630 uint3 tpitg[[thread_position_in_threadgroup]],
4631 uint3 ntg[[threads_per_threadgroup]]) {
4632
4633 int i = tgpig.x;
4634 device float * embed_data = (device float *)(dst + i*args.nb1);
4635
4636 int half_ = args.dim / 2;
4637 for (int j = tpitg.x; j < half_; j += ntg.x) {
4638 float timestep = ((device float *)src0)[i];
4639 float freq = (float)exp(-log((float)args.max_period) * j / half_);
4640 float arg = timestep * freq;
4641 embed_data[j ] = cos(arg);
4642 embed_data[j + half_] = sin(arg);
4643 }
4644
4645 if (args.dim % 2 != 0 && tpitg.x == 0) {
4646 embed_data[2 * half_] = 0.f;
4647 }
4648}
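
// note: the standard sinusoidal timestep embedding: freq_j = max_period^(-j/half), embed[j] = cos(t*freq_j),
// embed[j + half] = sin(t*freq_j), with a single zero element appended when dim is odd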
4649
4650// bitonic sort implementation, following the CUDA kernels as a reference
4651typedef void (argsort_t)(
4652 constant ggml_metal_kargs_argsort & args,
4653 device const char * src0,
4654 device int32_t * dst,
4655 threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
4656 uint3 tgpig[[threadgroup_position_in_grid]],
4657 ushort3 tpitg[[thread_position_in_threadgroup]],
4658 ushort3 ntg[[threads_per_threadgroup]]);
4659
4660template<ggml_sort_order order>
4661kernel void kernel_argsort_f32_i32(
4662 constant ggml_metal_kargs_argsort & args,
4663 device const char * src0,
4664 device int32_t * dst,
4665 threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
4666 uint3 tgpig[[threadgroup_position_in_grid]],
4667 ushort3 tpitg[[thread_position_in_threadgroup]],
4668 ushort3 ntg[[threads_per_threadgroup]]) {
4669 // bitonic sort
4670 const int col = tpitg[0];
4671 const int ib = tgpig[0] / args.ne01;
4672
4673 const int i00 = ib*ntg.x;
4674 const int i01 = tgpig[0] % args.ne01;
4675 const int i02 = tgpig[1];
4676 const int i03 = tgpig[2];
4677
4678 device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
4679
4680 // initialize indices
4681 shmem_i32[col] = i00 + col;
4682
4683 threadgroup_barrier(mem_flags::mem_threadgroup);
4684
4685 for (int k = 2; k <= ntg.x; k *= 2) {
4686 for (int j = k / 2; j > 0; j /= 2) {
4687 int ixj = col ^ j;
4688 if (ixj > col) {
4689 if ((col & k) == 0) {
4690 if (shmem_i32[col] >= args.ne00 ||
4691 (shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
4692 src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
4693 src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
4694 ) {
4695 SWAP(shmem_i32[col], shmem_i32[ixj]);
4696 }
4697 } else {
4698 if (shmem_i32[ixj] >= args.ne00 ||
4699 (shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
4700 src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
4701 src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
4702 ) {
4703 SWAP(shmem_i32[col], shmem_i32[ixj]);
4704 }
4705 }
4706 }
4707
4708 threadgroup_barrier(mem_flags::mem_threadgroup);
4709 }
4710 }
4711
4712 const int64_t i0 = ib*args.top_k;
4713
4714 // copy the result to dst without the padding
4715 if (i0 + col < args.ne0 && col < args.top_k) {
4716 dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03;
4717
4718 dst[col] = shmem_i32[col];
4719 }
4720}
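
// note: each row is sorted in chunks of ntg.x elements (padded to a power of two, as the bitonic network
// requires); padding indices >= ne00 always compare as "greater", so they sink to the end, and only the
// first top_k entries of each sorted chunk are written (chunks are later combined by the merge kernel below)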
4721
4722template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
4723template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
4724
4725typedef void (argsort_merge_t)(
4726 constant ggml_metal_kargs_argsort_merge & args,
4727 device const char * src0,
4728 device const int32_t * tmp,
4729 device int32_t * dst,
4730 uint3 tgpig[[threadgroup_position_in_grid]],
4731 ushort3 tpitg[[thread_position_in_threadgroup]],
4732 ushort3 ntg[[threads_per_threadgroup]]);
4733
4734template<ggml_sort_order order>
4735kernel void kernel_argsort_merge_f32_i32(
4736 constant ggml_metal_kargs_argsort_merge & args,
4737 device const char * src0,
4738 device const int32_t * tmp,
4739 device int32_t * dst,
4740 uint3 tgpig[[threadgroup_position_in_grid]],
4741 ushort3 tpitg[[thread_position_in_threadgroup]],
4742 ushort3 ntg[[threads_per_threadgroup]]) {
4743
4744 const int im = tgpig[0] / args.ne01;
4745 const int i01 = tgpig[0] % args.ne01;
4746 const int i02 = tgpig[1];
4747 const int i03 = tgpig[2];
4748
4749 const int start = im * (2 * args.len);
4750
4751 const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start)));
4752 const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len)));
4753
4754 const int total = len0 + len1;
4755
4756 device const int32_t * tmp0 = tmp + start
4757 + i01*args.ne0
4758 + i02*args.ne0*args.ne01
4759 + i03*args.ne0*args.ne01*args.ne02;
4760
4761 device const int32_t * tmp1 = tmp0 + args.len;
4762
4763 dst += start
4764 + i01*args.top_k
4765 + i02*args.top_k*args.ne01
4766 + i03*args.top_k*args.ne01*args.ne02;
4767
4768 device const float * src0_row = (device const float *)(src0
4769 + args.nb01*i01
4770 + args.nb02*i02
4771 + args.nb03*i03);
4772
4773 if (total == 0) {
4774 return;
4775 }
4776
4777 const int chunk = (total + ntg.x - 1) / ntg.x;
4778
4779 const int k0 = tpitg.x * chunk;
4780 const int k1 = MIN(MIN(k0 + chunk, total), args.top_k);
4781
4782 if (k0 >= args.top_k) {
4783 return;
4784 }
4785
4786 if (k0 >= total) {
4787 return;
4788 }
4789
4790 int low = k0 > len1 ? k0 - len1 : 0;
4791 int high = MIN(k0, len0);
4792
4793 // binary-search partition (i, j) such that i + j = k
4794 while (low < high) {
4795 const int mid = (low + high) >> 1;
4796
4797 const int32_t idx0 = tmp0[mid];
4798 const int32_t idx1 = tmp1[k0 - mid - 1];
4799
4800 const float val0 = src0_row[idx0];
4801 const float val1 = src0_row[idx1];
4802
4803 bool take_left;
4804 if (order == GGML_SORT_ORDER_ASC) {
4805 take_left = (val0 <= val1);
4806 } else {
4807 take_left = (val0 >= val1);
4808 }
4809
4810 if (take_left) {
4811 low = mid + 1;
4812 } else {
4813 high = mid;
4814 }
4815 }
4816
4817 int i = low;
4818 int j = k0 - i;
4819
4820 // keep the merge fronts in registers
4821 int32_t idx0 = 0;
4822 float val0 = 0.0f;
4823 if (i < len0) {
4824 idx0 = tmp0[i];
4825 val0 = src0_row[idx0];
4826 }
4827
4828 int32_t idx1 = 0;
4829 float val1 = 0.0f;
4830 if (j < len1) {
4831 idx1 = tmp1[j];
4832 val1 = src0_row[idx1];
4833 }
4834
4835 for (int k = k0; k < k1; ++k) {
4836 int32_t out_idx;
4837
4838 if (i >= len0) {
4839 while (k < k1) {
4840 dst[k++] = tmp1[j++];
4841 }
4842 break;
4843 } else if (j >= len1) {
4844 while (k < k1) {
4845 dst[k++] = tmp0[i++];
4846 }
4847 break;
4848 } else {
4849 bool take_left;
4850
4851 if (order == GGML_SORT_ORDER_ASC) {
4852 take_left = (val0 <= val1);
4853 } else {
4854 take_left = (val0 >= val1);
4855 }
4856
4857 if (take_left) {
4858 out_idx = idx0;
4859 ++i;
4860 if (i < len0) {
4861 idx0 = tmp0[i];
4862 val0 = src0_row[idx0];
4863 }
4864 } else {
4865 out_idx = idx1;
4866 ++j;
4867 if (j < len1) {
4868 idx1 = tmp1[j];
4869 val1 = src0_row[idx1];
4870 }
4871 }
4872 }
4873
4874 dst[k] = out_idx;
4875 }
4876}
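
// note: the binary search above is a merge-path partition: it finds i (elements consumed from the left
// run) and j = k0 - i (from the right run) so that every output in [k0, k1) can be produced locally;
// each thread then merges its slice sequentially without any cross-thread communication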
4877
4878template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
4879template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
4880
4881constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]];
4882
4883constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]];
4884
4885// pad the last chunk of C elements of k and v into an extra pad buffer
4886kernel void kernel_flash_attn_ext_pad(
4887 constant ggml_metal_kargs_flash_attn_ext_pad & args,
4888 device const char * k,
4889 device const char * v,
4890 device const char * mask,
4891 device char * dst,
4892 uint3 tgpig[[threadgroup_position_in_grid]],
4893 ushort tiitg[[thread_index_in_threadgroup]],
4894 ushort3 ntg[[threads_per_threadgroup]]) {
4895 const int32_t C = FC_flash_attn_ext_pad_ncpsg;
4896
4897 device char * k_pad = dst;
4898 device char * v_pad = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3;
4899 device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3;
4900
4901 const int32_t icp = args.ne11 % C;
4902 const int32_t ic0 = args.ne11 - icp;
4903
4904 const int32_t i1 = tgpig[0];
4905 const int32_t i2 = tgpig[1];
4906 const int32_t i3 = tgpig[2];
4907
4908 if (i2 < args.ne_12_2 && i3 < args.ne_12_3) {
4909 device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3;
4910 device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3;
4911
4912 device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3;
4913 device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3;
4914
4915 if (i1 >= icp) {
4916 // the exact value used here is not important, since we rely on masking out these scores in the attention
4917 for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
4918 k_dst[i] = 0;
4919 }
4920 for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
4921 v_dst[i] = 0;
4922 }
4923 } else {
4924 for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) {
4925 k_dst[i] = k_src[i];
4926 }
4927 for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) {
4928 v_dst[i] = v_src[i];
4929 }
4930 }
4931 }
4932
4933 if (FC_flash_attn_ext_pad_has_mask) {
4934 if (i2 < args.ne32 && i3 < args.ne33) {
4935 for (int ib = i1; ib < args.ne31; ib += C) {
4936 device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0;
4937 device half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3;
4938
4939 for (int i = tiitg; i < C; i += ntg.x) {
4940 if (i >= icp) {
4941 mask_dst[i] = -MAXHALF;
4942 } else {
4943 mask_dst[i] = mask_src[i];
4944 }
4945 }
4946 }
4947 }
4948 }
4949}
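
// note: the dst pad buffer is laid out as [k_pad | v_pad | mask_pad] with sizes nb11*C*ne_12_2*ne_12_3,
// nb21*C*ne_12_2*ne_12_3 and sizeof(half)*C*ne31*ne32*ne33 bytes, matching the offsets computed above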
4950
4951constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]];
4952constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]];
4953
4954// classify each block of the mask:
4955// 0 - masked (i.e. full of -INF, can be skipped)
4956// 1 - not masked (i.e. at least one element of the mask is not -INF)
4957// 2 - all zero (i.e. the mask addition can be skipped)
4958kernel void kernel_flash_attn_ext_blk(
4959 constant ggml_metal_kargs_flash_attn_ext_blk & args,
4960 device const char * mask,
4961 device char * dst,
4962 uint3 tgpig[[threadgroup_position_in_grid]],
4963 ushort tiisg[[thread_index_in_simdgroup]]) {
4964 // block size C x Q
4965 const int32_t Q = FC_flash_attn_ext_blk_nqptg;
4966 const int32_t C = FC_flash_attn_ext_blk_ncpsg;
4967
4968 constexpr short NW = N_SIMDWIDTH;
4969
4970 const int32_t i3 = tgpig[2]/args.ne32;
4971 const int32_t i2 = tgpig[2]%args.ne32;
4972 const int32_t i1 = tgpig[1];
4973 const int32_t i0 = tgpig[0];
4974
4975 char res = i0*C + C > args.ne30 ? 1 : 0;
4976
4977 device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg;
4978
4979 // detailed check of the elements of the block
4980 if ((C > NW || Q > 1) && res == 0) {
4981 half mmin = MAXHALF;
4982 half mmax = -MAXHALF;
4983
4984 FOR_UNROLL (short j = 0; j < Q; ++j) {
4985 FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) {
4986 mmin = min(mmin, mask_src[ii*NW]);
4987 mmax = max(mmax, mask_src[ii*NW]);
4988 }
4989
4990 mask_src += args.nb31/2;
4991 }
4992
4993 mmin = simd_min(mmin);
4994 mmax = simd_max(mmax);
4995
4996 if (mmax > -MAXHALF) {
4997 if (mmin == 0.0 && mmax == 0.0) {
4998 res = 2;
4999 } else {
5000 res = 1;
5001 }
5002 }
5003 }
5004
5005 const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
5006 const int32_t nblk0 = ((args.ne30 + C - 1)/C);
5007
5008 if (tiisg == 0) {
5009 dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res;
5010 }
5011}
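
// note: the per-block classification is consumed by kernel_flash_attn_ext below: blocks marked 0 are
// skipped entirely, blocks marked 1 load the mask into shared memory, and blocks marked 2 skip the
// mask addition altogether since the mask is all zero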
5012
5013constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]];
5014constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]];
5015constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]];
5016constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]];
5017constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]];
5018
5019constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
5020
5021//constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]];
5022//constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]];
5023//constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]];
5024
5025constant int32_t FC_flash_attn_ext_ns10 [[function_constant(FC_FLASH_ATTN_EXT + 20)]];
5026constant int32_t FC_flash_attn_ext_ns20 [[function_constant(FC_FLASH_ATTN_EXT + 21)]];
5027constant int32_t FC_flash_attn_ext_nsg [[function_constant(FC_FLASH_ATTN_EXT + 22)]];
5028
5029// ref: https://arxiv.org/pdf/2307.08691.pdf
5030template<
5031 typename q_t, // query types in shared memory
5032 typename q4_t,
5033 typename q8x8_t,
5034 typename k_t, // key types in shared memory
5035 typename k4x4_t,
5036 typename k8x8_t,
5037 typename v_t, // value types in shared memory
5038 typename v4x4_t,
5039 typename v8x8_t,
5040 typename qk_t, // Q*K types
5041 typename qk8x8_t,
5042 typename s_t, // soft-max types
5043 typename s2_t,
5044 typename s8x8_t,
5045 typename o_t, // attention accumulation types
5046 typename o4_t,
5047 typename o8x8_t,
5048 typename kd4x4_t, // key type in device memory
5049 short nl_k,
5050 void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
5051 typename vd4x4_t, // value type in device memory
5052 short nl_v,
5053 void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
5054 short DK, // K head size
5055 short DV, // V head size
5056 short Q, // queries per threadgroup
5057 short C, // cache items per threadgroup
5058 short NSG> // number of simd groups
5059void kernel_flash_attn_ext_impl(
5060 constant ggml_metal_kargs_flash_attn_ext & args,
5061 device const char * q,
5062 device const char * k,
5063 device const char * v,
5064 device const char * mask,
5065 device const char * sinks,
5066 device const char * pad,
5067 device const char * blk,
5068 device char * dst,
5069 threadgroup half * shmem_f16,
5070 uint3 tgpig,
5071 ushort tiisg,
5072 ushort sgitg) {
5073 const ushort iq3 = tgpig[2];
5074 const ushort iq2 = tgpig[1];
5075 const ushort iq1 = tgpig[0]*Q;
5076
5077#define NS10 (FC_flash_attn_ext_ns10)
5078#define NS20 (FC_flash_attn_ext_ns20)
5079
5080 // note: I had some concerns that using this instead of the ugly macros above might be affecting performance
5081 // need to re-check carefully and if no regressions are observed - remove the macros
5082 // the concern is that using const variables may require extra registers, though it is not clear if the compiler
5083 // is clever enough to avoid this. unfortunately, using constexpr is not possible with FC
5084 //const short NS10 = FC_flash_attn_ext_ns10;
5085 //const short NS20 = FC_flash_attn_ext_ns20;
5086
5087 constexpr short KV = 8;
5088
5089 constexpr short DK4 = DK/4;
5090 constexpr short DK8 = DK/8;
5091 constexpr short DK16 = DK/16;
5092 constexpr short DV4 = DV/4;
5093 //constexpr short DV8 = DV/8;
5094 constexpr short DV16 = DV/16;
5095
5096 constexpr short PV = PAD2(DV, 64);
5097 constexpr short PV4 = PV/4;
5098 constexpr short PV8 = PV/8;
5099 //constexpr short PV16 = PV/16;
5100
5101 constexpr short NW = N_SIMDWIDTH;
5102 constexpr short NQ = Q/NSG;
5103 constexpr short SH = 2*C; // shared memory per simdgroup (s_t == float)
5104
5105 constexpr short TS = 2*SH;
5106 constexpr short T = DK + 2*PV; // shared memory size per query in (half)
5107
5108 threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*T); // holds the query data
5109 threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*T); // same as above but in q4_t
5110 threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*T + Q*DK); // the result for all queries in 8x8 matrices (the O matrix from the paper)
5111 threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*T + Q*DK);
5112 threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + Q*T); // scratch buffer for attention, mask and diagonal matrix
5113 threadgroup s2_t * ss2 = (threadgroup s2_t *) (shmem_f16 + Q*T); // same as above but in s2_t
5114
5115 threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load K in shared memory
5116 threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in k4x4_t
5117
5118 threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load V in shared memory
5119 threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in v4x4_t
5120
5121 // mask storage in shared mem
5122 threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C);
5123
5124 // per-query mask pointers
5125 device const half2 * pm2[NQ];
5126
5127 FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5128 const short j = jj*NSG + sgitg;
5129
5130 pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
5131 }
5132
5133 {
5134 const int32_t nblk1 = ((args.ne01 + Q - 1)/Q);
5135 const int32_t nblk0 = ((args.ne11 + C - 1)/C);
5136
5137 blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0;
5138 }
5139
5140 {
5141 q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
5142
5143 const short ikv2 = iq2/(args.ne02/args.ne_12_2);
5144 const short ikv3 = iq3/(args.ne03/args.ne_12_3);
5145
5146 k += ikv2*args.nb12 + ikv3*args.nb13;
5147 v += ikv2*args.nb22 + ikv3*args.nb23;
5148 }
5149
5150 // load heads from Q to shared memory
5151 FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5152 const short j = jj*NSG + sgitg;
5153
5154 device const float4 * q4 = (device const float4 *) ((device const char *) q + j*args.nb01);
5155
5156 for (short i = tiisg; i < DK4; i += NW) {
5157 if (iq1 + j < args.ne01) {
5158 sq4[j*DK4 + i] = (q4_t) q4[i];
5159 } else {
5160 sq4[j*DK4 + i] = 0;
5161 }
5162 }
5163 }
5164
5165 // zero out
5166 FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5167 const short j = jj*NSG + sgitg;
5168
5169 for (short i = tiisg; i < DV4; i += NW) {
5170 so4[j*PV4 + i] = 0;
5171 }
5172
5173 for (short i = tiisg; i < SH; i += NW) {
5174 ss[j*SH + i] = 0.0f;
5175 }
5176 }
5177
5178 threadgroup_barrier(mem_flags::mem_threadgroup);
5179
5180 float S[NQ] = { [0 ... NQ-1] = 0.0f };
5181
5182 {
5183 float M[NQ] = { [0 ... NQ-1] = -FLT_MAX/2 };
5184
5185 float slope = 1.0f;
5186
5187 // ALiBi
5188 if (FC_flash_attn_ext_has_bias) {
5189 const short h = iq2;
5190
5191 const float base = h < args.n_head_log2 ? args.m0 : args.m1;
5192 const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
5193
5194 slope = pow(base, exph);
5195 }
5196
5197 // loop over the KV cache
5198 // each simdgroup handles blocks of Q rows and C columns
5199 for (int ic0 = 0; ; ++ic0) {
5200 int ic = ic0*C;
5201 if (ic >= args.ne11) {
5202 break;
5203 }
5204
5205 // the last partial chunk uses the pad buffer as source
5206 if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) {
5207 k = pad;
5208 v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
5209 mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
5210
5211 const short ikv2 = iq2/(args.ne02/args.ne_12_2);
5212 const short ikv3 = iq3/(args.ne03/args.ne_12_3);
5213
5214 k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
5215 v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
5216
5217 if (!FC_flash_attn_ext_has_mask) {
5218 threadgroup half * sm = (threadgroup half *) (sm2);
5219
5220 FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5221 const short j = jj*NSG + sgitg;
5222
5223 for (short i = tiisg; i < C; i += NW) {
5224 if (ic + i >= args.ne11) {
5225 sm[2*j*SH + i] = -MAXHALF;
5226 }
5227 }
5228 }
5229 } else {
5230 FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5231 const short j = jj*NSG + sgitg;
5232
5233 pm2[jj] = (device const half2 *) ((device const half *) mask +
5234 (iq1 + j)*C +
5235 (iq2%args.ne32)*(C*args.ne31) +
5236 (iq3%args.ne33)*(C*args.ne31*args.ne32));
5237 }
5238 }
5239
5240 ic = 0;
5241 }
5242
5243 char blk_cur = 1;
5244
5245 // read the mask into shared mem
5246 if (FC_flash_attn_ext_has_mask) {
5247 blk_cur = blk[ic0];
5248
5249 if (blk_cur == 0) {
5250 FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5251 pm2[jj] += NW;
5252 }
5253
5254 continue;
5255 }
5256
5257 if (blk_cur == 1) {
5258 FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5259 const short j = jj*NSG + sgitg;
5260
5261 if (FC_flash_attn_ext_bc_mask) {
5262 sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF);
5263 } else {
5264 sm2[j*SH + tiisg] = pm2[jj][tiisg];
5265 }
5266
5267 pm2[jj] += NW;
5268 }
5269 } else if (blk_cur == 2) {
5270 FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5271 pm2[jj] += NW;
5272 }
5273 }
5274
5275#if 0
5276 // note: old -INF block optimization - obsoleted by pre-computing non-masked blocks
5277
5278 threadgroup_barrier(mem_flags::mem_threadgroup);
5279
5280 // used to detect blocks full of -INF
5281 // skip only when the entire threadgroup is masked
5282 half2 smax2(-MAXHALF/2, -MAXHALF/2);
5283
5284 FOR_UNROLL (short j = 0; j < Q; ++j) {
5285 smax2 = max(smax2, sm2[j*SH + tiisg]);
5286 }
5287
5288 smax2 = simd_max(smax2);
5289
5290 if (max(smax2[0], smax2[1]) <= -MAXHALF/2) {
5291 // this barrier is important
5292 threadgroup_barrier(mem_flags::mem_threadgroup);
5293
5294 continue;
5295 }
5296#endif
5297 }
5298
5299 // Q*K^T
5300 // this is compile-time check, so it does not have runtime overhead
5301 if (is_same<kd4x4_t, k4x4_t>::value) {
5302 // we can read directly from global memory
5303 device const k_t * pk = (device const k_t *) (k + ic*args.nb11);
5304 threadgroup const q_t * pq = sq;
5305 threadgroup s_t * ps = ss;
5306
5307 pk += sgitg*(8*NS10);
5308 ps += sgitg*(8*1);
5309
5310 static_assert((C/8) % NSG == 0, "");
5311
5312 constexpr short NC = (C/8)/NSG;
5313
5314 FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
5315 qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
5316
5317 if (DK % 16 != 0) {
5318 k8x8_t mk;
5319 q8x8_t mq;
5320
5321 FOR_UNROLL (short i = 0; i < DK8; ++i) {
5322 simdgroup_barrier(mem_flags::mem_none);
5323
5324 simdgroup_load(mk, pk + 8*i, NS10, 0, true);
5325 simdgroup_load(mq, pq + 8*i, DK);
5326
5327 simdgroup_barrier(mem_flags::mem_none);
5328
5329 simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
5330 }
5331 } else {
5332 k8x8_t mk[2];
5333 q8x8_t mq[2];
5334
5335 // note: too much unroll can tank the performance for large heads
5336 #pragma unroll (MIN(DK8/2, 4*NSG))
5337 for (short i = 0; i < DK8/2; ++i) {
5338 simdgroup_barrier(mem_flags::mem_none);
5339
5340 simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
5341 simdgroup_load(mq[1], pq + 1*8 + 16*i, DK);
5342
5343 simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true);
5344 simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true);
5345
5346 simdgroup_barrier(mem_flags::mem_none);
5347
5348 simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk);
5349 simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk);
5350 }
5351 }
5352
5353 simdgroup_store(mqk, ps, SH, 0, false);
5354
5355 pk += 8*(NSG*NS10);
5356 ps += 8*(NSG);
5357 }
5358 } else {
5359 // TODO: this is the quantized K cache branch - not optimized yet
5360 for (short ccc = 0; ccc < (C/8)/NSG; ++ccc) {
5361 const short cc = ccc*NSG + sgitg;
5362
5363 const short tx = tiisg%4;
5364 const short ty = tiisg/4;
5365
5366 qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
5367
5368 for (short ii = 0; ii < DK16; ii += 4) {
5369 device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11));
5370
5371 if (DK16%4 == 0) {
5372 // the head is evenly divisible by 4*16 = 64, so no need for bound checks
5373 {
5374 k4x4_t tmp;
5375 deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
5376 sk4x4[4*ty + tx] = tmp;
5377 }
5378
5379 simdgroup_barrier(mem_flags::mem_threadgroup);
5380
5381 FOR_UNROLL (short k = 0; k < 4; ++k) {
5382 k8x8_t mk;
5383 q8x8_t mq;
5384
5385 simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
5386 simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
5387 simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
5388
5389 simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
5390 simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
5391 simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
5392 }
5393 } else {
5394 if (ii + tx < DK16) {
5395 k4x4_t tmp;
5396 deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
5397 sk4x4[4*ty + tx] = tmp;
5398 }
5399
5400 simdgroup_barrier(mem_flags::mem_threadgroup);
5401
5402 for (short k = 0; k < 4 && ii + k < DK16; ++k) {
5403 k8x8_t mk;
5404 q8x8_t mq;
5405
5406 simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
5407 simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK);
5408 simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
5409
5410 simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
5411 simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK);
5412 simdgroup_multiply_accumulate(mqk, mq, mk, mqk);
5413 }
5414 }
5415 }
5416
5417 simdgroup_store(mqk, ss + 8*cc, SH, 0, false);
5418 }
5419 }
5420
5421 threadgroup_barrier(mem_flags::mem_threadgroup);
5422
5423 // online softmax
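        // note: streaming softmax update per query row (M = running max, S = running denominator):
        //
        //   M' = max(M, max_c s_c)
        //   S' = S*exp(M - M') + sum_c exp(s_c - M')
        //   O' = O*exp(M - M')
        //
        // so the accumulated output O only needs a rescale by exp(M - M') whenever the max changes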
5424 FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5425 const short j = jj*NSG + sgitg;
5426
5427 const float m = M[jj];
5428
5429 // scale and apply the logitcap / mask
5430 float2 s2 = ss2[j*SH/2 + tiisg]*args.scale;
5431
5432 if (FC_flash_attn_ext_has_scap) {
5433 s2 = args.logit_softcap*precise::tanh(s2);
5434 }
5435
5436 // mqk = mqk + slope*mask
5437 if (blk_cur != 2) {
5438 if (FC_flash_attn_ext_has_bias) {
5439 s2 += s2_t(sm2[j*SH + tiisg])*slope;
5440 } else {
5441 s2 += s2_t(sm2[j*SH + tiisg]);
5442 }
5443 }
5444
5445 M[jj] = simd_max(max(M[jj], max(s2[0], s2[1])));
5446
5447 const float ms = exp(m - M[jj]);
5448 const float2 vs2 = exp(s2 - M[jj]);
5449
5450 S[jj] = S[jj]*ms + simd_sum(vs2[0] + vs2[1]);
5451
5452 // the P matrix from the paper (Q rows, C columns)
5453 ss2[j*SH/2 + tiisg] = vs2;
5454
5455 if (DV4 % NW == 0) {
5456 FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
5457 const short i = ii*NW + tiisg;
5458
5459 so4[j*PV4 + i] *= ms;
5460 }
5461 } else {
5462 for (short i = tiisg; i < DV4; i += NW) {
5463 so4[j*PV4 + i] *= ms;
5464 }
5465 }
5466 }
5467
5468 threadgroup_barrier(mem_flags::mem_threadgroup);
5469
5470 // O = O + (Q*K^T)*V
5471 {
5472 // we can read directly from global memory
5473 if (is_same<vd4x4_t, v4x4_t>::value) {
5474 static_assert(PV8 % NSG == 0, "");
5475
5476 constexpr short NO = PV8/NSG;
5477
5478 o8x8_t lo[NO];
5479
5480 {
5481 auto sot = so + 8*sgitg;
5482
5483 FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
5484 simdgroup_load(lo[ii], sot, PV, 0, false);
5485
5486 sot += 8*NSG;
5487 }
5488 }
5489
5490 {
5491 device const v_t * pv = (device const v_t *) (v + ic*args.nb21);
5492
5493 pv += 8*sgitg;
5494
5495 if (DV <= 64) {
5496 FOR_UNROLL (short cc = 0; cc < C/8; ++cc) {
5497 s8x8_t vs;
5498 simdgroup_load(vs, ss + 8*cc, SH, 0, false);
5499
5500 FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
5501 v8x8_t mv[2];
5502
5503 simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false);
5504 simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false);
5505
5506 simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]);
5507 simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]);
5508 }
5509
5510 pv += 8*NS20;
5511 }
5512 } else {
5513 constexpr short NC = (C/8)/2;
5514
5515 FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
5516 s8x8_t vs[2];
5517
5518 simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
5519 simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false);
5520
5521 FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) {
5522 v8x8_t mv[4];
5523
5524 simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
5525 simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false);
5526 simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
5527 simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false);
5528
5529 simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]);
5530 simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]);
5531 simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]);
5532 simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]);
5533 }
5534
5535 pv += 2*8*NS20;
5536 }
5537 }
5538 }
5539
5540 {
5541 auto sot = so + 8*sgitg;
5542
5543 FOR_UNROLL (short ii = 0; ii < NO; ++ii) {
5544 simdgroup_store(lo[ii], sot, PV, 0, false);
5545
5546 sot += 8*NSG;
5547 }
5548 }
5549 } else {
5550 // TODO: this is the quantized V cache branch - not optimized yet
5551
5552 const short tx = tiisg%4;
5553 const short ty = tiisg/4;
5554
5555 for (short cc = 0; cc < C/8; ++cc) {
5556 s8x8_t vs;
5557 simdgroup_load(vs, ss + 8*cc, SH, 0, false);
5558
5559 for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) {
5560 device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21));
5561
5562 if (DV16%4 == 0) {
5563 // no need for bound checks
5564 {
5565 v4x4_t tmp;
5566 deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
5567 sv4x4[4*ty + tx] = tmp;
5568 }
5569
5570 simdgroup_barrier(mem_flags::mem_threadgroup);
5571
5572 FOR_UNROLL (short k = 0; k < 4; ++k) {
5573 v8x8_t mv[2];
5574 o8x8_t lo[2];
5575
5576 simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false);
5577 simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false);
5578 simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
5579 simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
5580
5581 simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]);
5582 simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]);
5583
5584 simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
5585 simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
5586 }
5587 } else {
5588 if (ii + tx < DV16) {
5589 v4x4_t tmp;
5590 deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
5591 sv4x4[4*ty + tx] = tmp;
5592 }
5593
5594 simdgroup_barrier(mem_flags::mem_threadgroup);
5595
5596 for (short k = 0; k < 4 && ii + k < DV16; ++k) {
5597 v8x8_t mv[2];
5598 o8x8_t lo[2];
5599
5600 simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false);
5601 simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false);
5602 simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
5603 simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
5604
5605 simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]);
5606 simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]);
5607
5608 simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false);
5609 simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false);
5610 }
5611 }
5612 }
5613 }
5614 }
5615 }
5616
5617 threadgroup_barrier(mem_flags::mem_threadgroup);
5618 }
5619
5620 if (FC_flash_attn_ext_has_sinks) {
5621 FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
5622 const short j = jj*NSG + sgitg;
5623
5624 const float m = M[jj];
5625 const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
5626
5627 M[jj] = simd_max(max(M[jj], s));
5628
5629 const float ms = exp(m - M[jj]);
5630 const float vs = exp(s - M[jj]);
5631
5632 S[jj] = S[jj]*ms + simd_sum(vs);
5633
5634 for (short i = tiisg; i < DV4; i += NW) {
5635 so4[j*PV4 + i] *= ms;
5636 }
5637 }
5638 }
5639 }
5640
5641 // store to global memory
5642 for (short jj = 0; jj < NQ; ++jj) {
5643 const short j = jj*NSG + sgitg;
5644 if (iq1 + j >= args.ne01) {
5645 break;
5646 }
5647
5648 device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
5649
5650 const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj];
5651
5652 if (DV4 % NW == 0) {
5653 FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
5654 const short i = ii*NW + tiisg;
5655
5656 dst4[i] = (float4) so4[j*PV4 + i]*scale;
5657 }
5658 } else {
5659 for (short i = tiisg; i < DV4; i += NW) {
5660 dst4[i] = (float4) so4[j*PV4 + i]*scale;
5661 }
5662 }
5663 }
5664
5665#undef NS10
5666#undef NS20
5667}
5668
5669template<
5670 typename q_t, // query types in shared memory
5671 typename q4_t,
5672 typename q8x8_t,
5673 typename k_t, // key types in shared memory
5674 typename k4x4_t,
5675 typename k8x8_t,
5676 typename v_t, // value types in shared memory
5677 typename v4x4_t,
5678 typename v8x8_t,
5679 typename qk_t, // Q*K types
5680 typename qk8x8_t,
5681 typename s_t, // soft-max types
5682 typename s2_t,
5683 typename s8x8_t,
5684 typename o_t, // attention accumulation types
5685 typename o4_t,
5686 typename o8x8_t,
5687 typename kd4x4_t, // key type in device memory
5688 short nl_k,
5689 void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
5690 typename vd4x4_t, // value type in device memory
5691 short nl_v,
5692 void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
5693 short DK, // K head size
5694 short DV, // V head size
5695 short Q = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup
5696 short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup
5697kernel void kernel_flash_attn_ext(
5698 constant ggml_metal_kargs_flash_attn_ext & args,
5699 device const char * q,
5700 device const char * k,
5701 device const char * v,
5702 device const char * mask,
5703 device const char * sinks,
5704 device const char * pad,
5705 device const char * blk,
5706 device char * dst,
5707 threadgroup half * shmem_f16 [[threadgroup(0)]],
5708 uint3 tgpig[[threadgroup_position_in_grid]],
5709 ushort tiisg[[thread_index_in_simdgroup]],
5710 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5711#define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C
5712#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg
5713 switch (FC_flash_attn_ext_nsg) {
5714 // note: disabled cases to reduce library load time
5715 //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
5716 //case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
5717 case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
5718 case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
5719 }
5720#undef FWD_TMPL
5721#undef FWD_ARGS
5722}
5723
5724// TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as
5725// templates to be able to explore different combinations
5726//
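// each line of the FA_TYPES macros below corresponds to one group of the template
// parameters of kernel_flash_attn_ext, in order: (q_t, q4_t, q8x8_t), (k_t, k4x4_t, k8x8_t),
// (v_t, v4x4_t, v8x8_t), (qk_t, qk8x8_t), (s_t, s2_t, s8x8_t), (o_t, o4_t, o8x8_t)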
5727#define FA_TYPES \
5728 half, half4, simdgroup_half8x8, \
5729 half, half4x4, simdgroup_half8x8, \
5730 half, half4x4, simdgroup_half8x8, \
5731 float, simdgroup_float8x8, \
5732 float, float2, simdgroup_float8x8, \
5733 float, float4, simdgroup_float8x8
5734 //half, half4, simdgroup_half8x8
5735
5736#define FA_TYPES_BF \
5737 bfloat, bfloat4, simdgroup_bfloat8x8, \
5738 bfloat, bfloat4x4, simdgroup_bfloat8x8, \
5739 bfloat, bfloat4x4, simdgroup_bfloat8x8, \
5740 float, simdgroup_float8x8, \
5741 float, float2, simdgroup_float8x8, \
5742 half, half4, simdgroup_half8x8
5743 //float, float4, simdgroup_float8x8
5744
5745#define FA_TYPES_F32 \
5746 half, half4, simdgroup_half8x8, \
5747 float, float4x4, simdgroup_float8x8, \
5748 float, float4x4, simdgroup_float8x8, \
5749 float, simdgroup_float8x8, \
5750 float, float2, simdgroup_float8x8, \
5751 float, float4, simdgroup_float8x8
5752 //half, half4, simdgroup_half8x8
5753
5754typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
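// the host looks these pipelines up by name: kernel_flash_attn_ext_<kv-type>_dk<DK>_dv<DV>,
// where <kv-type> is the storage type of the K/V cache; the asymmetric dk192_dv128 and
// dk576_dv512 shapes cover models whose K and V head sizes differ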
5755
5756template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
5757template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
5758template [[host_name("kernel_flash_attn_ext_f32_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 48, 48>;
5759template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
5760template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 72, 72>;
5761template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
5762template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
5763template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
5764template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 128, 128>;
5765template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
5766template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
5767template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
5768template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
5769
5770template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
5771template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
5772template [[host_name("kernel_flash_attn_ext_f16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 48, 48>;
5773template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
5774template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 72, 72>;
5775template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
5776template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>;
5777template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>;
5778template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128, 128>;
5779template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>;
5780template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>;
5781template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>;
5782template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
5783
5784#if defined(GGML_METAL_HAS_BF16)
5785template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
5786template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
5787template [[host_name("kernel_flash_attn_ext_bf16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 48, 48>;
5788template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
5789template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 72, 72>;
5790template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
5791template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
5792template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
5793template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
5794template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
5795template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
5796template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
5797template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
5798#endif
5799
5800template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
5801template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
5802template [[host_name("kernel_flash_attn_ext_q4_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 48, 48>;
5803template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
5804template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72, 72>;
5805template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
5806template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>;
5807template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>;
5808template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>;
5809template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>;
5810template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>;
5811template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
5812template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
5813
5814template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
5815template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
5816template [[host_name("kernel_flash_attn_ext_q4_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 48, 48>;
5817template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
5818template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72, 72>;
5819template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
5820template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>;
5821template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>;
5822template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>;
5823template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>;
5824template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>;
5825template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
5826template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
5827
5828template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
5829template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
5830template [[host_name("kernel_flash_attn_ext_q5_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 48, 48>;
5831template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
5832template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72, 72>;
5833template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
5834template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>;
5835template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>;
5836template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>;
5837template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>;
5838template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>;
5839template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
5840template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
5841
5842template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
5843template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
5844template [[host_name("kernel_flash_attn_ext_q5_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 48, 48>;
5845template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
5846template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72, 72>;
5847template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
5848template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>;
5849template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>;
5850template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>;
5851template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>;
5852template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>;
5853template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
5854template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
5855
5856template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
5857template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
5858template [[host_name("kernel_flash_attn_ext_q8_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 48, 48>;
5859template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
5860template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72, 72>;
5861template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
5862template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>;
5863template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>;
5864template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>;
5865template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>;
5866template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>;
5867template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>;
5868template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
5869
5870#undef FA_TYPES
5871#undef FA_TYPES_BF
5872#undef FA_TYPES_F32
5873
5874constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]];
5875constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
5876constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]];
5877constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]];
5878constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]];
5879
5880//constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]];
5881//constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]];
5882//constant float FC_flash_attn_ext_vec_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 12)]];
5883
5884constant int32_t FC_flash_attn_ext_vec_ns10 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 20)]];
5885constant int32_t FC_flash_attn_ext_vec_ns20 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 21)]];
5886constant int32_t FC_flash_attn_ext_vec_nsg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 22)]];
5887constant int32_t FC_flash_attn_ext_vec_nwg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 23)]];
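// note: NSG is the number of simdgroups per threadgroup and NWG the number of threadgroups
// cooperating on one row; NS10/NS20 are the K/V row strides in elements. all four are fixed
// at pipeline-creation time via function constants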
5888
5889template<
5890 typename q4_t, // query types in shared memory
5891 typename k4_t, // key types in shared memory
5892 typename v4_t, // value types in shared memory
5893 typename qk_t, // Q*K types
5894 typename s_t, // soft-max types
5895 typename s4_t,
5896 typename o4_t, // attention accumulation types
5897 typename kd4_t, // key type in device memory
5898 short nl_k,
5899 void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &),
5900 typename vd4_t, // value type in device memory
5901 short nl_v,
5902 void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &),
5903 short DK, // K head size
5904 short DV, // V head size
5905 short NE = 4, // head elements per thread
5906 short Q = OP_FLASH_ATTN_EXT_VEC_NQPSG, // queries per threadgroup
5907 short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup
5908kernel void kernel_flash_attn_ext_vec(
5909 constant ggml_metal_kargs_flash_attn_ext_vec & args,
5910 device const char * q,
5911 device const char * k,
5912 device const char * v,
5913 device const char * mask,
5914 device const char * sinks,
5915 device const char * pad,
5916 device char * dst,
5917 threadgroup half * shmem_f16 [[threadgroup(0)]],
5918 uint3 tgpig[[threadgroup_position_in_grid]],
5919 ushort tiisg[[thread_index_in_simdgroup]],
5920 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
5921 static_assert(DK % 32 == 0, "DK must be divisible by 32");
5922 static_assert(DV % 32 == 0, "DV must be divisible by 32");
5923
5924#define NWG (FC_flash_attn_ext_vec_nwg)
5925#define NSG (FC_flash_attn_ext_vec_nsg)
5926
5927#define NS10 (FC_flash_attn_ext_vec_ns10)
5928#define NS20 (FC_flash_attn_ext_vec_ns20)
5929
5930 const short iwg = tgpig[2]%NWG;
5931
5932 const ushort iq3 = tgpig[2]/NWG;
5933 const ushort iq2 = tgpig[1];
5934 const ushort iq1 = tgpig[0];
5935
5936 constexpr short DK4 = DK/4;
5937 constexpr short DV4 = DV/4;
5938
5939 constexpr short PK = PAD2(DK, 128);
5940 constexpr short PK4 = PK/4;
5941
5942 constexpr short PV = PAD2(DV, 128);
5943 constexpr short PV4 = PV/4;
5944
5945 constexpr short NW = N_SIMDWIDTH;
5946 constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup workloads
5947 constexpr short SH = 4*C; // shared memory per simdgroup
5948
5949 static_assert(DK4 % NL == 0, "DK4 must be divisible by NL");
5950 static_assert(DV4 % NL == 0, "DV4 must be divisible by NL");
5951
5952 const short T = PK + NSG*SH; // shared memory size per query, in half units
5953
5954 //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data
5955 threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t
5956 threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + NSG*PK); // scratch buffer for attention
5957 threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + NSG*PK); // same as above but in s4_t
5958 threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask
5959 threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + NSG*PK + NSG*SH); // scratch buffer for the results
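// shared memory layout, in half units: the query is stored at offset 0 (NSG*PK reserved),
// followed by NSG*SH of per-simdgroup scratch (scores in ss/ss4, mask in sm), followed by
// the per-simdgroup output accumulators (2*PV halfs each, since o4_t is a float4)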
5960
5961 // store the result for all queries in shared memory (the O matrix from the paper)
5962 so4 += tiisg;
5963
5964 {
5965 q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03;
5966
5967 const short ikv2 = iq2/(args.ne02/args.ne_12_2);
5968 const short ikv3 = iq3/(args.ne03/args.ne_12_3);
5969
5970 k += ikv2*args.nb12 + ikv3*args.nb13;
5971 v += ikv2*args.nb22 + ikv3*args.nb23;
5972 }
5973
5974 // load heads from Q to shared memory
5975 device const float4 * q4 = (device const float4 *) ((device const char *) q);
5976
5977 if (iq1 < args.ne01) {
5978 for (short i = tiisg; i < PK4; i += NW) {
5979 if (i < DK4) {
5980 sq4[i] = (q4_t) q4[i];
5981 } else {
5982 sq4[i] = (q4_t) 0.0f;
5983 }
5984 }
5985 }
5986
5987 // zero out so
5988 for (short i = 0; i < DV4/NL; ++i) {
5989 so4[i*NL] = (o4_t) 0.0f;
5990 }
5991
5992 // zero out shared memory SH
5993 for (short i = tiisg; i < SH/4; i += NW) {
5994 ss4[i] = (s4_t) 0.0f;
5995 }
5996
5997 threadgroup_barrier(mem_flags::mem_threadgroup);
5998
5999 {
6000 float S = 0.0f;
6001 float M = -FLT_MAX/2;
6002
6003 // thread indices inside the simdgroup
6004 const short tx = tiisg%NL;
6005 const short ty = tiisg/NL;
6006
6007 // pointer to the mask
6008 device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
6009
6010 float slope = 1.0f;
6011
6012 // ALiBi
6013 if (FC_flash_attn_ext_vec_has_bias) {
6014 const short h = iq2;
6015
6016 const float base = h < args.n_head_log2 ? args.m0 : args.m1;
6017 const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
6018
6019 slope = pow(base, exph);
6020 }
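// e.g. with n_head_log2 = 8: heads 0..7 get slope m0^(h+1), later heads m1^(2*(h-8)+1),
// i.e. the standard ALiBi per-head slope schedule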
6021
6022 // loop over the KV cache
6023 // each simdgroup handles blocks of Q rows and C columns
6024 for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) {
6025 int ic = ic0*C;
6026 if (ic >= args.ne11) {
6027 break;
6028 }
6029
6030 // the last partial chunk uses the pad buffer as source
6031 if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) {
6032 k = pad;
6033 v = k + args.nb11*C*args.ne_12_2*args.ne_12_3;
6034 mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3;
6035
6036 const short ikv2 = iq2/(args.ne02/args.ne_12_2);
6037 const short ikv3 = iq3/(args.ne03/args.ne_12_3);
6038
6039 k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C;
6040 v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
6041
6042 if (!FC_flash_attn_ext_vec_has_mask) {
6043 if (ic + tiisg >= args.ne11) {
6044 sm[tiisg] = -MAXHALF;
6045 }
6046 } else {
6047 pm = (device const half *) (mask) +
6048 iq1*C +
6049 (iq2%args.ne32)*(C*args.ne31) +
6050 (iq3%args.ne33)*(C*args.ne31*args.ne32);
6051 }
6052
6053 ic = 0;
6054 }
6055
6056 if (FC_flash_attn_ext_vec_has_mask) {
6057 sm[tiisg] = pm[ic + tiisg];
6058 }
6059
6060 // skip -INF blocks
6061 if (simd_max(sm[tiisg]) <= -MAXHALF) {
6062 continue;
6063 }
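// (a fully masked block would contribute exp(-inf) = 0 to both S and O, so skipping it
// does not change the result)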
6064
6065 // Q*K^T
6066 {
6067 device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11);
6068 threadgroup const q4_t * pq4 = sq4;
6069
6070 pk4 += ty*NS10/4 + tx;
6071 pq4 += tx;
6072
6073 qk_t mqk[C/NE] = { [0 ... C/NE - 1] = 0.0f }; // zero-init (GNU range designator)
6074
6075 // each simdgroup processes 1 query and NE (NW/NL) cache elements
6076 FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
6077 if (is_same<kd4_t, k4_t>::value) {
6078 FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) {
6079 mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]);
6080 }
6081 } else {
6082 device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11));
6083
6084 k4_t mk;
6085
6086 FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) {
6087 const short i = ii*NL + tx;
6088
6089 deq_k_t4(pk + i/nl_k, i%nl_k, mk);
6090
6091 mqk[cc] += dot((float4) mk, (float4) sq4[i]);
6092 }
6093 }
6094
6095 if (NE == 1) {
6096 mqk[cc] = simd_sum(mqk[cc]);
6097 } else {
6098 // simdgroup reduce (NE = 4)
6099 // [ 0 .. 7] -> [ 0]
6100 // [ 8 .. 15] -> [ 8]
6101 // [16 .. 23] -> [16]
6102 // [24 .. 31] -> [24]
6103 if (NE <= 1) {
6104 mqk[cc] += simd_shuffle_down(mqk[cc], 16);
6105 }
6106 if (NE <= 2) {
6107 mqk[cc] += simd_shuffle_down(mqk[cc], 8);
6108 }
6109 if (NE <= 4) {
6110 mqk[cc] += simd_shuffle_down(mqk[cc], 4);
6111 }
6112 if (NE <= 8) {
6113 mqk[cc] += simd_shuffle_down(mqk[cc], 2);
6114 }
6115 if (NE <= 16) {
6116 mqk[cc] += simd_shuffle_down(mqk[cc], 1);
6117 }
6118
6119 // broadcast
6120 mqk[cc] = simd_shuffle(mqk[cc], NL*ty);
6121 }
6122 }
6123
6124 if (FC_flash_attn_ext_vec_has_mask &&
6125 !FC_flash_attn_ext_vec_has_scap &&
6126 !FC_flash_attn_ext_vec_has_bias) {
6127 ss[NE*tx + ty] = fma(mqk[tx], args.scale, (qk_t) sm[NE*tx + ty]);
6128 } else {
6129 mqk[tx] *= args.scale;
6130
6131 if (FC_flash_attn_ext_vec_has_scap) {
6132 mqk[tx] = args.logit_softcap*precise::tanh(mqk[tx]);
6133 }
6134
6135 if (FC_flash_attn_ext_vec_has_bias) {
6136 mqk[tx] += (qk_t) sm[NE*tx + ty]*slope;
6137 } else {
6138 mqk[tx] += (qk_t) sm[NE*tx + ty];
6139 }
6140
6141 ss[NE*tx + ty] = mqk[tx];
6142 }
6143 }
6144
6145 simdgroup_barrier(mem_flags::mem_threadgroup);
6146
6147 // online softmax
6148 {
6149 const float m = M;
6150 const float s = ss[tiisg];
6151
6152 M = simd_max(max(M, s));
6153
6154 const float ms = exp(m - M);
6155 const float vs = exp(s - M);
6156
6157 S = S*ms + simd_sum(vs);
6158
6159 // the P matrix from the paper (Q rows, C columns)
6160 ss[tiisg] = vs;
6161
6162 // O = diag(ms)*O
6163 if ((DV4/NL % NW == 0) || ty == 0) {
6164 FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
6165 so4[ii*NL] *= ms;
6166 }
6167 }
6168 }
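// i.e. the streaming softmax recurrence: with M' = max(M, max_c s_c) over the new block,
// S <- S*exp(M - M') + sum_c exp(s_c - M') and O <- O*exp(M - M'), so M tracks the
// running maximum and S the rescaled denominator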
6169
6170 simdgroup_barrier(mem_flags::mem_threadgroup);
6171
6172 // O = O + (Q*K^T)*V
6173 {
6174 o4_t lo[DV4/NL];
6175 FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
6176 lo[ii] = 0.0f;
6177 }
6178
6179 if (is_same<vd4_t, v4_t>::value) {
6180 device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21);
6181
6182 pv4 += ty*NS20/4 + tx;
6183
6184 const auto sst = ss + ty;
6185
6186 FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
6187 FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
6188 lo[ii] += o4_t(float4(pv4[cc*NE*NS20/4 + ii*NL])*float4(sst[cc*NE]));
6189 }
6190 }
6191 } else {
6192 FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) {
6193 device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21));
6194
6195 FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
6196 const short i = ii*NL + tx;
6197
6198 v4_t mv;
6199 deq_v_t4(pv4 + i/nl_v, i%nl_v, mv);
6200
6201 lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty]));
6202 }
6203 }
6204 }
6205
6206 FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
6207 if (NE > 1) {
6208 lo[ii][0] += simd_shuffle_down(lo[ii][0], 16);
6209 lo[ii][1] += simd_shuffle_down(lo[ii][1], 16);
6210 lo[ii][2] += simd_shuffle_down(lo[ii][2], 16);
6211 lo[ii][3] += simd_shuffle_down(lo[ii][3], 16);
6212 }
6213
6214 if (NE > 2) {
6215 lo[ii][0] += simd_shuffle_down(lo[ii][0], 8);
6216 lo[ii][1] += simd_shuffle_down(lo[ii][1], 8);
6217 lo[ii][2] += simd_shuffle_down(lo[ii][2], 8);
6218 lo[ii][3] += simd_shuffle_down(lo[ii][3], 8);
6219 }
6220
6221 if (NE > 4) {
6222 lo[ii][0] += simd_shuffle_down(lo[ii][0], 4);
6223 lo[ii][1] += simd_shuffle_down(lo[ii][1], 4);
6224 lo[ii][2] += simd_shuffle_down(lo[ii][2], 4);
6225 lo[ii][3] += simd_shuffle_down(lo[ii][3], 4);
6226 }
6227
6228 if (NE > 8) {
6229 lo[ii][0] += simd_shuffle_down(lo[ii][0], 2);
6230 lo[ii][1] += simd_shuffle_down(lo[ii][1], 2);
6231 lo[ii][2] += simd_shuffle_down(lo[ii][2], 2);
6232 lo[ii][3] += simd_shuffle_down(lo[ii][3], 2);
6233 }
6234
6235 if (NE > 16) {
6236 lo[ii][0] += simd_shuffle_down(lo[ii][0], 1);
6237 lo[ii][1] += simd_shuffle_down(lo[ii][1], 1);
6238 lo[ii][2] += simd_shuffle_down(lo[ii][2], 1);
6239 lo[ii][3] += simd_shuffle_down(lo[ii][3], 1);
6240 }
6241 }
6242
6243 if ((DV4/NL % NW == 0) || ty == 0) {
6244 FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
6245 so4[ii*NL] += lo[ii];
6246 }
6247 }
6248 }
6249 }
6250
6251 if (FC_flash_attn_ext_vec_has_sinks && sgitg == 0 && iwg == 0) {
6252 const float m = M;
6253 const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
6254
6255 M = simd_max(max(M, s));
6256
6257 const float ms = exp(m - M);
6258 const float vs = exp(s - M);
6259
6260 S = S*ms + simd_sum(vs);
6261
6262 if ((DV4/NL % NW == 0) || ty == 0) {
6263 FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) {
6264 so4[ii*NL] *= ms;
6265 }
6266 }
6267 }
6268
6269 // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
6270 if (tiisg == 0) {
6271 ss[0] = (s_t) S;
6272 ss[1] = (s_t) M;
6273 }
6274 }
6275
6276 so4 -= tiisg;
6277
6278 threadgroup_barrier(mem_flags::mem_threadgroup);
6279
6280 // parallel reduce
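// tree reduction over the NSG simdgroups: each step merges pairs of partial (S, M, O)
// results with the same rescaling as the online softmax, halving the active simdgroups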
6281 for (short r = NSG/2; r > 0; r >>= 1) {
6282 if (sgitg < r) {
6283 const float S0 = ss[ 0];
6284 const float S1 = ss[r*(SH/2) + 0];
6285
6286 const float M0 = ss[ 1];
6287 const float M1 = ss[r*(SH/2) + 1];
6288
6289 const float M = max(M0, M1);
6290
6291 const float ms0 = exp(M0 - M);
6292 const float ms1 = exp(M1 - M);
6293
6294 const float S = S0*ms0 + S1*ms1;
6295
6296 if (tiisg == 0) {
6297 ss[0] = S;
6298 ss[1] = M;
6299 }
6300
6301 // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
6302 for (short i = tiisg; i < DV4; i += NW) {
6303 so4[i] = so4[i]*ms0 + so4[i + r*PV4]*ms1;
6304 }
6305 }
6306
6307 threadgroup_barrier(mem_flags::mem_threadgroup);
6308 }
6309
6310 // final rescale with 1/S and store to global memory
6311 if (sgitg == 0) {
6312 const int64_t nrows = args.ne3*args.ne2*args.ne1;
6313 const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;
6314
6315 device float4 * dst4 = (device float4 *) dst;
6316 device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results
6317
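// when NWG > 1 the KV cache is split across workgroups and the 1/S normalization is
// deferred to kernel_flash_attn_ext_vec_reduce, so the partial sums are stored unscaled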
6318 const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f;
6319
6320 // interleave the workgroup data
6321 for (short i = tiisg; i < DV4; i += NW) {
6322 dst4[rid*DV4*NWG + NWG*i + iwg] = (float4) so4[i]*S;
6323 }
6324
6325 // store S and M
6326 if (NWG > 1) {
6327 if (tiisg == 0) {
6328 dst1[rid*(2*NWG) + 2*iwg + 0] = ss[0];
6329 dst1[rid*(2*NWG) + 2*iwg + 1] = ss[1];
6330 }
6331 }
6332 }
6333
6334#undef NWG
6335#undef NSG
6336#undef NS10
6337#undef NS20
6338}
6339
6340// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
6341// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
6342//
6343#define FA_TYPES \
6344 half4, \
6345 half4, \
6346 half4, \
6347 float, \
6348 float, float4, \
6349 float4
6350
6351#define FA_TYPES_F32 \
6352 half4, \
6353 float4, \
6354 float4, \
6355 float, \
6356 float, float4, \
6357 float4
6358
6359typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
6360
6361template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 32, 32, 4>;
6362template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 32, 32, 4>;
6363#if defined(GGML_METAL_HAS_BF16)
6364template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 32, 32, 4>;
6365#endif
6366template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 32, 32, 4>;
6367template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 32, 32, 4>;
6368template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 32, 32, 4>;
6369template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 32, 32, 4>;
6370template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 32, 32, 4>;
6371
6372template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 64, 64, 2>;
6373template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
6374#if defined(GGML_METAL_HAS_BF16)
6375template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
6376#endif
6377template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 2>;
6378template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>;
6379template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>;
6380template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
6381template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
6382
6383template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 96, 96, 4>;
6384template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
6385#if defined(GGML_METAL_HAS_BF16)
6386template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
6387#endif
6388template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
6389template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
6390template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
6391template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
6392template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
6393
6394template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 128, 128, 1>;
6395template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
6396#if defined(GGML_METAL_HAS_BF16)
6397template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
6398#endif
6399template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>;
6400template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>;
6401template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>;
6402template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
6403template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
6404
6405template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 192, 2>;
6406template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
6407#if defined(GGML_METAL_HAS_BF16)
6408template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
6409#endif
6410template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>;
6411template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>;
6412template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>;
6413template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
6414template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
6415
6416template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 128, 2>;
6417template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
6418#if defined(GGML_METAL_HAS_BF16)
6419template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
6420#endif
6421template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>;
6422template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>;
6423template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>;
6424template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
6425template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
6426
6427template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 256, 256, 1>;
6428template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
6429#if defined(GGML_METAL_HAS_BF16)
6430template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
6431#endif
6432template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>;
6433template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>;
6434template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>;
6435template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
6436template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
6437
6438template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
6439template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
6440#if defined(GGML_METAL_HAS_BF16)
6441template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
6442#endif
6443template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
6444template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
6445template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
6446template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
6447template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
6448
6449#undef FA_TYPES
6450#undef FA_TYPES_F32
6451
6452constant int32_t FC_flash_attn_ext_vec_reduce_DV [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]];
6453constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]];
6454
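// second pass for the NWG > 1 case: htmp holds the interleaved partial outputs
// (nrows*DV*NWG floats) followed by one (S, M) pair per workgroup and row; they are
// merged here with a log-sum-exp reduction and the normalized result is written to dst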
6455kernel void kernel_flash_attn_ext_vec_reduce(
6456 constant ggml_metal_kargs_flash_attn_ext_vec_reduce & args,
6457 device const char * htmp,
6458 device char * dst,
6459 uint tgpig[[threadgroup_position_in_grid]],
6460 ushort tiisg[[thread_index_in_simdgroup]],
6461 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
6462#define NWG (FC_flash_attn_ext_vec_reduce_NWG)
6463#define DV (FC_flash_attn_ext_vec_reduce_DV)
6464
6465 const uint64_t rid = tgpig;
6466
6467 const short iwg = tiisg;
6468
6469 device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*NWG;
6470
6471 float S = ss[rid*(2*NWG) + 2*iwg + 0];
6472 float M = ss[rid*(2*NWG) + 2*iwg + 1];
6473
6474 const float m = simd_max(M);
6475 const float ms = exp(M - m);
6476
6477 S = simd_sum(S*ms);
6478 S = S == 0.0f ? 0.0f : 1.0f/S;
6479
6480 const short DV4 = DV/4;
6481
6482 device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*NWG;
6483 device float4 * dst4 = (device float4 *) dst + rid*DV4;
6484
6485 for (short i = sgitg; i < DV4; i += NWG) {
6486 const float4 v = simd_sum(htmp4[i*NWG + iwg]*ms);
6487
6488 if (iwg == 0) {
6489 dst4[i] = v*S;
6490 }
6491 }
6492
6493#undef NWG
6494#undef DV
6495}
6496
6497template<typename T0, typename T1>
6498kernel void kernel_cpy_t_t(
6499 constant ggml_metal_kargs_cpy & args,
6500 device const char * src0,
6501 device char * dst,
6502 uint3 tgpig[[threadgroup_position_in_grid]],
6503 ushort tiitg[[thread_index_in_threadgroup]],
6504 ushort3 ntg[[threads_per_threadgroup]]) {
6505 const int i03 = tgpig[2];
6506 const int i02 = tgpig[1];
6507 const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
6508 const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
6509
6510 const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
6511
6512 const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
6513 const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
6514 const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
6515 const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
6516
6517 device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
6518
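// note: the for/break construct runs the body at most once: each thread copies a single
// element, and the loop condition only filters out threads whose i00 falls past ne00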
6519 for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
6520 device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
6521 dst_data[i00] = (T1) src[0];
6522 break;
6523 }
6524}
6525
6526typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;
6527
6528template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, float>;
6529template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
6530template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
6531template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
6532template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>;
6533#if defined(GGML_METAL_HAS_BF16)
6534template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
6535#endif
6536template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>;
6537template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<half, half>;
6538#if defined(GGML_METAL_HAS_BF16)
6539template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>;
6540template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>;
6541#endif
6542
6543template<short QK,
6544 typename block_q,
6545 void (*quantize_func)(device const float *, device block_q &)>
6546kernel void kernel_cpy_f32_q(
6547 constant ggml_metal_kargs_cpy & args,
6548 device const char * src0,
6549 device char * dst,
6550 uint3 tgpig[[threadgroup_position_in_grid]],
6551 ushort tiitg[[thread_index_in_threadgroup]],
6552 ushort3 ntg[[threads_per_threadgroup]]) {
6553 const int i03 = tgpig[2];
6554 const int i02 = tgpig[1];
6555 const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
6556 const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
6557
6558 const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
6559
6560 const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
6561 const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
6562 const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
6563 const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
6564
6565 device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
6566
6567 for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
6568 device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
6569
6570 quantize_func(src, dst_data[i00]);
6571
6572 break;
6573 }
6574}
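
// [editor's note] In kernel_cpy_f32_q each thread quantizes one whole block of QK
// source floats: i0 is divided by QK so dst_data indexes blocks, the source pointer
// advances by i00*QK elements, and args.nk0 is the row length in blocks (presumably
// ne00/QK). E.g. with QK8_0 = 32, a row of ne00 = 4096 floats becomes nk0 = 128
// block_q8_0 writes, one quantize_q8_0 call per thread.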
6575
6576typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
6577
6578template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
6579template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
6580template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
6581template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
6582template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1, block_q5_1, quantize_q5_1>;
6583template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
6584
6585template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
6586kernel void kernel_cpy_q_f32(
6587 constant ggml_metal_kargs_cpy & args,
6588 device const char * src0,
6589 device char * dst,
6590 uint3 tgpig[[threadgroup_position_in_grid]],
6591 ushort tiitg[[thread_index_in_threadgroup]],
6592 ushort3 ntg[[threads_per_threadgroup]]) {
6593 const int i03 = tgpig[2];
6594 const int i02 = tgpig[1];
6595 const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
6596 const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
6597
6598 const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
6599
6600 const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
6601 const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
6602 const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
6603 const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
6604
6605 device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
6606 device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
6607
6608 for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
6609 T4x4 temp;
6610 dequantize_func(src_data + i00/nl, i00%nl, temp);
6611 dst_data[i00] = temp;
6612
6613 break;
6614 }
6615}
6616
6617typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
6618
6619template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
6620template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
6621template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
6622template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
6623template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
6624
6625template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
6626template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
6627template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
6628template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
6629template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
6630
6631kernel void kernel_concat(
6632 constant ggml_metal_kargs_concat & args,
6633 device const char * src0,
6634 device const char * src1,
6635 device char * dst,
6636 uint3 tgpig[[threadgroup_position_in_grid]],
6637 ushort3 tpitg[[thread_position_in_threadgroup]],
6638 ushort3 ntg[[threads_per_threadgroup]]) {
6639
6640 const int i3 = tgpig.z;
6641 const int i2 = tgpig.y;
6642 const int i1 = tgpig.x;
6643
6644 int o[4] = {0, 0, 0, 0};
6645 o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
6646
6647 device const float * x;
6648
6649 for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
6650 if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
6651 x = (device const float *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00);
6652 } else {
6653 x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
6654 }
6655
6656 device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
6657
6658 *y = *x;
6659 }
6660}
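
// [editor's note] kernel_concat picks the source per element: coordinates inside
// src0's box read src0 directly, everything else reads src1 shifted back by o[], where
// o[args.dim] is src0's extent along the concat axis. Rough example, assuming dim = 1
// with ne01 = 3 and dst ne1 = 5: rows i1 = 0..2 come from src0 and rows i1 = 3..4 come
// from src1 rows (i1 - 3) = 0..1.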
6661
6662template<int nr0, typename args_t>
6663void kernel_mul_mv_q2_K_f32_impl(
6664 args_t args,
6665 device const char * src0,
6666 device const char * src1,
6667 device char * dst,
6668 threadgroup char * shmem,
6669 uint3 tgpig,
6670 ushort tiisg,
6671 ushort sgitg) {
6672 const short NSG = FC_mul_mv_nsg;
6673
6674 const int nb = args.ne00/QK_K;
6675
6676 const int r0 = tgpig.x;
6677 const int r1 = tgpig.y;
6678 const int im = tgpig.z;
6679
6680 const int first_row = (r0 * NSG + sgitg) * nr0;
6681
6682 const uint i12 = im%args.ne12;
6683 const uint i13 = im/args.ne12;
6684
6685 const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
6686 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
6687
6688 device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
6689 device const float * y = (device const float *) (src1 + offset1);
6690
6691 float yl[32];
6692 float sumf[nr0]={0.f};
6693
6694 const short ix = tiisg/8; // 0...3
6695 const short it = tiisg%8; // 0...7
6696 const short iq = it/4; // 0 or 1
6697 const short ir = it%4; // 0...3
6698    const short is = (8*ir)/16; // 0 or 1
6699
6700 device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
6701
6702 for (int ib = ix; ib < nb; ib += 4) {
6703 float4 sumy = {0.f, 0.f, 0.f, 0.f};
6704 for (short i = 0; i < 8; ++i) {
6705 yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
6706 yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
6707 yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
6708 yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
6709 }
6710
6711 device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is;
6712 device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
6713 device const half * dh = &x[ib].d;
6714
6715 for (short row = 0; row < nr0; row++) {
6716 float4 acc1 = {0.f, 0.f, 0.f, 0.f};
6717 float4 acc2 = {0.f, 0.f, 0.f, 0.f};
6718 for (int i = 0; i < 8; i += 2) {
6719 acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
6720 acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
6721 acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
6722 acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
6723 acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
6724 acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
6725 acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
6726 acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
6727 }
6728 float dall = dh[0];
6729 float dmin = dh[1] * 1.f/16.f;
6730 sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
6731 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
6732 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
6733 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
6734 dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
6735
6736 qs += args.nb01/2;
6737 sc += args.nb01;
6738 dh += args.nb01/2;
6739 }
6740
6741 y4 += 4 * QK_K;
6742 }
6743
6744 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
6745
6746 for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
6747 float sum_all = simd_sum(sumf[row]);
6748 if (tiisg == 0) {
6749 dst_f32[first_row + row] = sum_all;
6750 }
6751 }
6752}
6753
6754[[host_name("kernel_mul_mv_q2_K_f32")]]
6755kernel void kernel_mul_mv_q2_K_f32(
6756 constant ggml_metal_kargs_mul_mv & args,
6757 device const char * src0,
6758 device const char * src1,
6759 device char * dst,
6760 uint3 tgpig[[threadgroup_position_in_grid]],
6761 ushort tiisg[[thread_index_in_simdgroup]],
6762 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
6763
6764 kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
6765}
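
// [editor's note] The acc1/acc2 split in the q2_K kernel is the usual high-byte trick:
// qs is read as uint16 and masks like 0x0300 pick quant bits living in the high byte,
// so those products carry a spurious factor of 256 that the trailing 1.f/256.f
// removes. Likewise the 2-bit fields of groups 1..3 are used in place (masks 0x000c,
// 0x0030, 0x00c0, i.e. factors 4/16/64) and divided back out next to the group scales:
// e.g. q = 2 stored in bits 4..5 is read as 32, and yl*32 * (sc & 0xF) * 1.f/16.f
// restores yl * 2 * scale. The min path works the same way: the min nibble sits in the
// high half of the scale byte (sc & 0xF0 = 16*min), cancelled by dmin = dh[1]/16.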
6766
6767template<int nr0, typename args_t>
6768void kernel_mul_mv_q3_K_f32_impl(
6769 args_t args,
6770 device const char * src0,
6771 device const char * src1,
6772 device char * dst,
6773 threadgroup char * shmem,
6774 uint3 tgpig,
6775 ushort tiisg,
6776 ushort sgitg) {
6777 const short NSG = FC_mul_mv_nsg;
6778
6779 const int nb = args.ne00/QK_K;
6780
6781 const int r0 = tgpig.x;
6782 const int r1 = tgpig.y;
6783 const int im = tgpig.z;
6784
6785 const int first_row = (r0 * NSG + sgitg) * nr0;
6786
6787 const uint i12 = im%args.ne12;
6788 const uint i13 = im/args.ne12;
6789
6790 const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
6791 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
6792
6793 device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
6794 device const float * yy = (device const float *) (src1 + offset1);
6795
6796 float yl[32];
6797
6798 //const uint16_t kmask1 = 0x3030;
6799 //const uint16_t kmask2 = 0x0f0f;
6800
6801 const short tid = tiisg/4;
6802 const short ix = tiisg%4;
6803 const short ip = tid/4; // 0 or 1
6804 const short il = 2*((tid%4)/2); // 0 or 2
6805 const short ir = tid%2;
6806 const short l0 = 8*ir;
6807
6808 // One would think that the Metal compiler would figure out that ip and il can only have
6809 // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
6810    // with these two tables.
6811 //
6812 // Possible masks for the high bit
6813 const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0
6814 {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2
6815 {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0
6816 {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
6817
6818 // Possible masks for the low 2 bits
6819 const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
6820
6821 const ushort4 hm = mm[2*ip + il/2];
6822
6823 const short shift = 2*il;
6824
6825 const float v1 = il == 0 ? 4.f : 64.f;
6826 const float v2 = 4.f * v1;
6827
6828 const uint16_t s_shift1 = 4*ip;
6829 const uint16_t s_shift2 = s_shift1 + il;
6830
6831 const short q_offset = 32*ip + l0;
6832 const short y_offset = 128*ip + 32*il + l0;
6833
6834 device const float * y1 = yy + ix*QK_K + y_offset;
6835
6836 uint32_t scales32, aux32;
6837 thread uint16_t * scales16 = (thread uint16_t *)&scales32;
6838 thread const int8_t * scales = (thread const int8_t *)&scales32;
6839
6840 float sumf1[nr0] = {0.f};
6841 float sumf2[nr0] = {0.f};
6842
6843 for (int i = ix; i < nb; i += 4) {
6844 for (short l = 0; l < 8; ++l) {
6845 yl[l+ 0] = y1[l+ 0];
6846 yl[l+ 8] = y1[l+16];
6847 yl[l+16] = y1[l+32];
6848 yl[l+24] = y1[l+48];
6849 }
6850
6851 device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
6852 device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
6853 device const uint16_t * a = (device const uint16_t *)(x[i].scales);
6854 device const half * dh = &x[i].d;
6855
6856 for (short row = 0; row < nr0; ++row) {
6857 const float d_all = (float)dh[0];
6858
6859 scales16[0] = a[4];
6860 scales16[1] = a[5];
6861 aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
6862 scales16[0] = a[il+0];
6863 scales16[1] = a[il+1];
6864 scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
6865
6866 float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
6867 for (short l = 0; l < 8; l += 2) {
6868 const int32_t qs = q[l/2];
6869 s1 += yl[l+0] * (qs & qm[il/2][0]);
6870 s2 += yl[l+1] * (qs & qm[il/2][1]);
6871 s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
6872 s4 += yl[l+16] * (qs & qm[il/2][2]);
6873 s5 += yl[l+17] * (qs & qm[il/2][3]);
6874 s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
6875 }
6876 float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
6877 float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
6878 sumf1[row] += d1 * (scales[0] - 32);
6879 sumf2[row] += d2 * (scales[2] - 32);
6880
6881 s1 = s2 = s3 = s4 = s5 = s6 = 0;
6882 for (short l = 0; l < 8; l += 2) {
6883 const int32_t qs = q[l/2+8];
6884 s1 += yl[l+8] * (qs & qm[il/2][0]);
6885 s2 += yl[l+9] * (qs & qm[il/2][1]);
6886 s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
6887 s4 += yl[l+24] * (qs & qm[il/2][2]);
6888 s5 += yl[l+25] * (qs & qm[il/2][3]);
6889 s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
6890 }
6891 d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
6892 d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
6893 sumf1[row] += d1 * (scales[1] - 32);
6894 sumf2[row] += d2 * (scales[3] - 32);
6895
6896 q += args.nb01/2;
6897 h += args.nb01/2;
6898 a += args.nb01/2;
6899 dh += args.nb01/2;
6900 }
6901
6902 y1 += 4 * QK_K;
6903 }
6904
6905 for (int row = 0; row < nr0; ++row) {
6906 const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
6907 sumf1[row] = simd_sum(sumf);
6908 }
6909
6910 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
6911
6912 if (tiisg == 0) {
6913 for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
6914 dst_f32[first_row + row] = sumf1[row];
6915 }
6916 }
6917}
6918
6919[[host_name("kernel_mul_mv_q3_K_f32")]]
6920kernel void kernel_mul_mv_q3_K_f32(
6921 constant ggml_metal_kargs_mul_mv & args,
6922 device const char * src0,
6923 device const char * src1,
6924 device char * dst,
6925 uint3 tgpig[[threadgroup_position_in_grid]],
6926 ushort tiisg[[thread_index_in_simdgroup]],
6927 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
6928
6929 kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
6930}
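
// [editor's note] The scales16/scales32 shuffle in the q3_K kernel unpacks four 6-bit
// scales at once: a[4..5] hold the packed high 2-bit halves (shifted left 4 and masked
// with 0x30303030) while a[il..il+1] hold the low 4-bit halves (masked with
// 0x0f0f0f0f); OR-ing the two yields four bytes in 0..63, read back through the int8_t
// alias and re-centered with "- 32". The final "/ (1 << shift)" (shift = 2*il) undoes
// the in-place bit position of this thread's 2-bit quants.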
6931
6932template<int nr0, typename args_t>
6933void kernel_mul_mv_q4_K_f32_impl(
6934 args_t args,
6935 device const char * src0,
6936 device const char * src1,
6937 device char * dst,
6938 threadgroup char * shmem,
6939 uint3 tgpig,
6940 ushort tiisg,
6941 ushort sgitg) {
6942 const short NSG = FC_mul_mv_nsg;
6943
6944 constexpr uint16_t kmask1 = 0x3f3f;
6945 constexpr uint16_t kmask2 = 0x0f0f;
6946 constexpr uint16_t kmask3 = 0xc0c0;
6947
6948 const short ix = tiisg/8; // 0...3
6949 const short it = tiisg%8; // 0...7
6950 const short iq = it/4; // 0 or 1
6951 const short ir = it%4; // 0...3
6952
6953 const int nb = args.ne00/QK_K;
6954
6955 const int r0 = tgpig.x;
6956 const int r1 = tgpig.y;
6957 const int im = tgpig.z;
6958
6959 const int first_row = (r0 * NSG + sgitg) * nr0;
6960
6961 const uint i12 = im%args.ne12;
6962 const uint i13 = im/args.ne12;
6963
6964 const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
6965 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
6966
6967 device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
6968 device const float * y = (device const float *) (src1 + offset1);
6969
6970 float yl[16];
6971 float yh[16];
6972
6973 float sumf[nr0]={0.f};
6974
6975 device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
6976
6977 uint16_t sc16[4];
6978 thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
6979
6980 for (int ib = ix; ib < nb; ib += 4) {
6981 float4 sumy = {0.f, 0.f, 0.f, 0.f};
6982
6983 for (short i = 0; i < 8; ++i) {
6984 yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0];
6985 yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
6986 yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
6987 yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
6988 }
6989
6990 device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq;
6991 device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
6992 device const half * dh = &x[ib].d;
6993
6994 for (short row = 0; row < nr0; row++) {
6995 sc16[0] = sc[0] & kmask1;
6996 sc16[1] = sc[2] & kmask1;
6997 sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
6998 sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
6999
7000 device const uint16_t * q2 = q1 + 32;
7001
7002 float4 acc1 = {0.f, 0.f, 0.f, 0.f};
7003 float4 acc2 = {0.f, 0.f, 0.f, 0.f};
7004
7005 FOR_UNROLL (short i = 0; i < 4; ++i) {
7006 acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
7007 acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
7008 acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
7009 acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000);
7010 acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F);
7011 acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00);
7012 acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0);
7013 acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
7014 }
7015
7016 sumf[row] += dh[0] * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
7017 (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
7018 (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
7019 (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
7020 dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
7021
7022 q1 += args.nb01/2;
7023 sc += args.nb01/2;
7024 dh += args.nb01/2;
7025 }
7026
7027 y4 += 4 * QK_K;
7028 }
7029
7030 device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
7031
7032 for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
7033 float sum_all = simd_sum(sumf[row]);
7034 if (tiisg == 0) {
7035 dst_f32[first_row + row] = sum_all;
7036 }
7037 }
7038}
7039
7040[[host_name("kernel_mul_mv_q4_K_f32")]]
7041kernel void kernel_mul_mv_q4_K_f32(
7042 constant ggml_metal_kargs_mul_mv & args,
7043 device const char * src0,
7044 device const char * src1,
7045 device char * dst,
7046 uint3 tgpig[[threadgroup_position_in_grid]],
7047 ushort tiisg[[thread_index_in_simdgroup]],
7048 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
7049
7050 kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
7051}
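
// [editor's note] q4_K packs 8 scales and 8 mins (6 bits each) into 12 bytes; the
// kmask1/2/3 extraction rebuilds them two at a time: sc16[0]/sc16[1] take the
// low-6-bit scales/mins of the first sub-blocks straight from sc[0]/sc[2], while
// sc16[2]/sc16[3] splice the low 4 bits from sc[4] with the high 2 bits rescued from
// sc[0]/sc[2] (the 0xc0c0 mask, shifted down 2). Through the sc8 byte alias,
// sc8[0..1]/sc8[4..5] are the scales and sc8[2..3]/sc8[6..7] the matching mins used in
// the dh[1]*(...) minuend.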
7052
7053template<int nr0, typename args_t>
7054void kernel_mul_mv_q5_K_f32_impl(
7055 args_t args,
7056 device const char * src0,
7057 device const char * src1,
7058 device char * dst,
7059 threadgroup char * shmem,
7060 uint3 tgpig,
7061 ushort tiisg,
7062 ushort sgitg) {
7063 const short NSG = FC_mul_mv_nsg;
7064
7065 const int nb = args.ne00/QK_K;
7066
7067 const int r0 = tgpig.x;
7068 const int r1 = tgpig.y;
7069 const int im = tgpig.z;
7070
7071 const int first_row = (r0 * NSG + sgitg) * nr0;
7072
7073 const uint i12 = im%args.ne12;
7074 const uint i13 = im/args.ne12;
7075
7076 const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7077 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7078
7079 device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
7080 device const float * yy = (device const float *) (src1 + offset1);
7081
7082 float sumf[nr0]={0.f};
7083
7084 float yl[16], yh[16];
7085
7086 constexpr uint16_t kmask1 = 0x3f3f;
7087 constexpr uint16_t kmask2 = 0x0f0f;
7088 constexpr uint16_t kmask3 = 0xc0c0;
7089
7090 const short tid = tiisg/4;
7091 const short ix = tiisg%4;
7092 const short iq = tid/4;
7093 const short ir = tid%4;
7094
7095 const short l0 = 8*ir;
7096 const short q_offset = 32*iq + l0;
7097 const short y_offset = 64*iq + l0;
7098
7099 const uint8_t hm1 = 1u << (2*iq);
7100 const uint8_t hm2 = hm1 << 1;
7101 const uint8_t hm3 = hm1 << 4;
7102 const uint8_t hm4 = hm2 << 4;
7103
7104 uint16_t sc16[4];
7105 thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
7106
7107 device const float * y1 = yy + ix*QK_K + y_offset;
7108
7109 for (int i = ix; i < nb; i += 4) {
7110 device const uint8_t * q1 = x[i].qs + q_offset;
7111 device const uint8_t * qh = x[i].qh + l0;
7112 device const half * dh = &x[i].d;
7113 device const uint16_t * a = (device const uint16_t *)x[i].scales + iq;
7114
7115 device const float * y2 = y1 + 128;
7116 float4 sumy = {0.f, 0.f, 0.f, 0.f};
7117 for (short l = 0; l < 8; ++l) {
7118 yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
7119 yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
7120 yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
7121 yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
7122 }
7123
7124 for (short row = 0; row < nr0; ++row) {
7125 device const uint8_t * q2 = q1 + 64;
7126
7127 sc16[0] = a[0] & kmask1;
7128 sc16[1] = a[2] & kmask1;
7129 sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
7130 sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
7131
7132 float4 acc1 = {0.f};
7133 float4 acc2 = {0.f};
7134 FOR_UNROLL (short l = 0; l < 8; ++l) {
7135 uint8_t h = qh[l];
7136 acc1[0] += yl[l+0] * (q1[l] & 0x0F);
7137 acc1[1] += yl[l+8] * (q1[l] & 0xF0);
7138 acc1[2] += yh[l+0] * (q2[l] & 0x0F);
7139 acc1[3] += yh[l+8] * (q2[l] & 0xF0);
7140 acc2[0] += h & hm1 ? yl[l+0] : 0.f;
7141 acc2[1] += h & hm2 ? yl[l+8] : 0.f;
7142 acc2[2] += h & hm3 ? yh[l+0] : 0.f;
7143 acc2[3] += h & hm4 ? yh[l+8] : 0.f;
7144 }
7145
7146 sumf[row] += dh[0] * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
7147 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
7148 sc8[4] * (acc1[2] + 16.f*acc2[2]) +
7149 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
7150 dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
7151
7152 q1 += args.nb01;
7153 qh += args.nb01;
7154 dh += args.nb01/2;
7155 a += args.nb01/2;
7156 }
7157
7158 y1 += 4 * QK_K;
7159 }
7160
7161 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
7162
7163 for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
7164 const float tot = simd_sum(sumf[row]);
7165 if (tiisg == 0) {
7166 dst_f32[first_row + row] = tot;
7167 }
7168 }
7169}
7170
7171[[host_name("kernel_mul_mv_q5_K_f32")]]
7172kernel void kernel_mul_mv_q5_K_f32(
7173 constant ggml_metal_kargs_mul_mv & args,
7174 device const char * src0,
7175 device const char * src1,
7176 device char * dst,
7177 uint3 tgpig[[threadgroup_position_in_grid]],
7178 ushort tiisg[[thread_index_in_simdgroup]],
7179 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
7180
7181 kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
7182}
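
// [editor's note] In q5_K each weight is 4 bits from qs plus a 5th bit from qh; the
// one-bit masks hm1..hm4 select that 5th bit for the four y groups and its weight of
// 16 enters via the "+ 16.f*acc2[...]" terms. The acc1[1]/16.f and acc1[3]/16.f
// divisions undo the in-place high-nibble reads (0xF0 masks), e.g. q = 5 stored in
// bits 4..7 is read as 80 and scaled back to 5.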
7183
7184template<int nr0, typename args_t>
7185void kernel_mul_mv_q6_K_f32_impl(
7186 args_t args,
7187 device const char * src0,
7188 device const char * src1,
7189 device char * dst,
7190 threadgroup char * shmem,
7191 uint3 tgpig,
7192 ushort tiisg,
7193 ushort sgitg) {
7194 const short NSG = FC_mul_mv_nsg;
7195
7196 constexpr uint8_t kmask1 = 0x03;
7197 constexpr uint8_t kmask2 = 0x0C;
7198 constexpr uint8_t kmask3 = 0x30;
7199 constexpr uint8_t kmask4 = 0xC0;
7200
7201 const int nb = args.ne00/QK_K;
7202
7203 const int r0 = tgpig.x;
7204 const int r1 = tgpig.y;
7205 const int im = tgpig.z;
7206
7207 const int first_row = (r0 * NSG + sgitg) * nr0;
7208
7209 const uint i12 = im%args.ne12;
7210 const uint i13 = im/args.ne12;
7211
7212 const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7213 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7214
7215 device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
7216 device const float * yy = (device const float *) (src1 + offset1);
7217
7218 float sumf[nr0] = { 0.f };
7219
7220 float yl[16];
7221
7222 const short tid = tiisg/2;
7223 const short ix = tiisg%2;
7224 const short ip = tid/8; // 0 or 1
7225 const short il = tid%8;
7226 const short l0 = 4*il;
7227 const short is = 8*ip + l0/16;
7228
7229 const short y_offset = 128*ip + l0;
7230 const short q_offset_l = 64*ip + l0;
7231 const short q_offset_h = 32*ip + l0;
7232
7233 for (int i = ix; i < nb; i += 2) {
7234 device const uint8_t * q1 = x[i].ql + q_offset_l;
7235 device const uint8_t * q2 = q1 + 32;
7236 device const uint8_t * qh = x[i].qh + q_offset_h;
7237 device const int8_t * sc = x[i].scales + is;
7238 device const half * dh = &x[i].d;
7239
7240 device const float * y = yy + i * QK_K + y_offset;
7241
7242 for (short l = 0; l < 4; ++l) {
7243 yl[4*l + 0] = y[l + 0];
7244 yl[4*l + 1] = y[l + 32];
7245 yl[4*l + 2] = y[l + 64];
7246 yl[4*l + 3] = y[l + 96];
7247 }
7248
7249 for (short row = 0; row < nr0; ++row) {
7250 float4 sums = {0.f, 0.f, 0.f, 0.f};
7251
7252 FOR_UNROLL (short l = 0; l < 4; ++l) {
7253 sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
7254 sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
7255 sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
7256 sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
7257 }
7258
7259 sumf[row] += dh[0] * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
7260
7261 q1 += args.nb01;
7262 q2 += args.nb01;
7263 qh += args.nb01;
7264 sc += args.nb01;
7265 dh += args.nb01/2;
7266 }
7267 }
7268
7269 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
7270
7271 for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
7272 float sum_all = simd_sum(sumf[row]);
7273 if (tiisg == 0) {
7274 dst_f32[first_row + row] = sum_all;
7275 }
7276 }
7277}
7278
7279[[host_name("kernel_mul_mv_q6_K_f32")]]
7280kernel void kernel_mul_mv_q6_K_f32(
7281 constant ggml_metal_kargs_mul_mv & args,
7282 device const char * src0,
7283 device const char * src1,
7284 device char * dst,
7285 uint3 tgpig[[threadgroup_position_in_grid]],
7286 ushort tiisg[[thread_index_in_simdgroup]],
7287 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
7288
7289 kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
7290}
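
// [editor's note] q6_K weights are rebuilt as (low 4 bits from ql/q2) | (2 high bits
// from qh shifted into bits 4..5), then re-centered with "- 32". Rough worked example
// for the first group: ql = 0xA7 and qh = 0x02 give
// (0x7 | ((0x02 & kmask1) << 4)) - 32 = (7 | 32) - 32 = 7.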
7291
7292// ======================= "True" 2-bit
7293
7294template<int nr0, typename args_t>
7295void kernel_mul_mv_iq2_xxs_f32_impl(
7296 args_t args,
7297 device const char * src0,
7298 device const char * src1,
7299 device char * dst,
7300 threadgroup char * shmem,
7301 uint3 tgpig,
7302 ushort tiisg,
7303 ushort sgitg) {
7304 const short NSG = FC_mul_mv_nsg;
7305
7306 const int nb = args.ne00/QK_K;
7307
7308 const int r0 = tgpig.x;
7309 const int r1 = tgpig.y;
7310 const int im = tgpig.z;
7311
7312 const int first_row = (r0 * NSG + sgitg) * nr0;
7313
7314 const uint i12 = im%args.ne12;
7315 const uint i13 = im/args.ne12;
7316
7317 const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7318 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7319
7320 device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
7321 device const float * y = (device const float *) (src1 + offset1);
7322
7323 float yl[32];
7324 float sumf[nr0]={0.f};
7325
7326 const int nb32 = nb * (QK_K / 32);
7327
7328 threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
7329 threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256);
7330 {
7331 int nval = 4;
7332 int pos = (32*sgitg + tiisg)*nval;
7333 for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i];
7334 nval = 2;
7335 pos = (32*sgitg + tiisg)*nval;
7336 for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
7337 threadgroup_barrier(mem_flags::mem_threadgroup);
7338 }
7339
7340 const int ix = tiisg;
7341
7342 device const float * y4 = y + 32 * ix;
7343
7344 for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
7345 for (short i = 0; i < 32; ++i) {
7346 yl[i] = y4[i];
7347 }
7348
7349 const int ibl = ib32 / (QK_K / 32);
7350 const int ib = ib32 % (QK_K / 32);
7351
7352 device const block_iq2_xxs * xr = x + ibl;
7353 device const uint16_t * q2 = xr->qs + 4 * ib;
7354 device const half * dh = &xr->d;
7355
7356 for (short row = 0; row < nr0; row++) {
7357 const float db = dh[0];
7358 device const uint8_t * aux8 = (device const uint8_t *)q2;
7359 const uint32_t aux32 = q2[2] | (q2[3] << 16);
7360 const float d = db * (0.5f + (aux32 >> 28));
7361
7362 float sum = 0;
7363 for (short l = 0; l < 4; ++l) {
7364 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
7365 const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
7366 for (short j = 0; j < 8; ++j) {
7367 sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
7368 }
7369 }
7370 sumf[row] += d * sum;
7371
7372 dh += args.nb01/2;
7373 q2 += args.nb01/2;
7374 }
7375
7376 y4 += 32 * 32;
7377 }
7378
7379 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
7380
7381 for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
7382 float sum_all = simd_sum(sumf[row]);
7383 if (tiisg == 0) {
7384 dst_f32[first_row + row] = sum_all * 0.25f;
7385 }
7386 }
7387}
7388
7389[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
7390kernel void kernel_mul_mv_iq2_xxs_f32(
7391 constant ggml_metal_kargs_mul_mv & args,
7392 device const char * src0,
7393 device const char * src1,
7394 device char * dst,
7395 threadgroup char * shmem [[threadgroup(0)]],
7396 uint3 tgpig[[threadgroup_position_in_grid]],
7397 ushort tiisg[[thread_index_in_simdgroup]],
7398 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
7399 kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
7400}
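
// [editor's note] iq2_xxs stages its 256-entry grid and the 128-byte sign table in
// threadgroup memory (each thread copies a handful of entries, then a barrier), which
// is why this wrapper, unlike the q*_K ones, passes a real shmem pointer. Per
// 32-weight group, aux32 packs four 7-bit sign indices plus a 4-bit scale in the top
// bits: d = db*(0.5f + (aux32 >> 28)), and the trailing *0.25f folds in the format's
// fixed 1/4 step.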
7401
7402template<int nr0, typename args_t>
7403void kernel_mul_mv_iq2_xs_f32_impl(
7404 args_t args,
7405 device const char * src0,
7406 device const char * src1,
7407 device char * dst,
7408 threadgroup char * shmem,
7409 uint3 tgpig,
7410 ushort tiisg,
7411 ushort sgitg) {
7412 const short NSG = FC_mul_mv_nsg;
7413
7414 const int nb = args.ne00/QK_K;
7415
7416 const int r0 = tgpig.x;
7417 const int r1 = tgpig.y;
7418 const int im = tgpig.z;
7419
7420 const int first_row = (r0 * NSG + sgitg) * nr0;
7421
7422 const uint i12 = im%args.ne12;
7423 const uint i13 = im/args.ne12;
7424
7425 const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7426 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7427
7428 device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
7429 device const float * y = (device const float *) (src1 + offset1);
7430
7431 float yl[32];
7432 float sumf[nr0]={0.f};
7433
7434 const int nb32 = nb * (QK_K / 32);
7435
7436 threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
7437 threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 512);
7438 {
7439 int nval = 8;
7440 int pos = (32*sgitg + tiisg)*nval;
7441 for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i];
7442 nval = 2;
7443 pos = (32*sgitg + tiisg)*nval;
7444 for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
7445 threadgroup_barrier(mem_flags::mem_threadgroup);
7446 }
7447
7448 const int ix = tiisg;
7449
7450 device const float * y4 = y + 32 * ix;
7451
7452 for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
7453 for (short i = 0; i < 32; ++i) {
7454 yl[i] = y4[i];
7455 }
7456
7457 const int ibl = ib32 / (QK_K / 32);
7458 const int ib = ib32 % (QK_K / 32);
7459
7460 device const block_iq2_xs * xr = x + ibl;
7461 device const uint16_t * q2 = xr->qs + 4 * ib;
7462 device const uint8_t * sc = xr->scales + ib;
7463 device const half * dh = &xr->d;
7464
7465 for (short row = 0; row < nr0; row++) {
7466 const float db = dh[0];
7467 const uint8_t ls1 = sc[0] & 0xf;
7468 const uint8_t ls2 = sc[0] >> 4;
7469 const float d1 = db * (0.5f + ls1);
7470 const float d2 = db * (0.5f + ls2);
7471
7472 float sum1 = 0, sum2 = 0;
7473 for (short l = 0; l < 2; ++l) {
7474 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
7475 const uint8_t signs = ssigns[(q2[l] >> 9)];
7476 for (short j = 0; j < 8; ++j) {
7477 sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
7478 }
7479 }
7480 for (short l = 2; l < 4; ++l) {
7481 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
7482 const uint8_t signs = ssigns[(q2[l] >> 9)];
7483 for (short j = 0; j < 8; ++j) {
7484 sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
7485 }
7486 }
7487 sumf[row] += d1 * sum1 + d2 * sum2;
7488
7489 dh += args.nb01/2;
7490 q2 += args.nb01/2;
7491 sc += args.nb01;
7492 }
7493
7494 y4 += 32 * 32;
7495 }
7496
7497 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
7498
7499 for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
7500 float sum_all = simd_sum(sumf[row]);
7501 if (tiisg == 0) {
7502 dst_f32[first_row + row] = sum_all * 0.25f;
7503 }
7504 }
7505}
7506
7507[[host_name("kernel_mul_mv_iq2_xs_f32")]]
7508kernel void kernel_mul_mv_iq2_xs_f32(
7509 constant ggml_metal_kargs_mul_mv & args,
7510 device const char * src0,
7511 device const char * src1,
7512 device char * dst,
7513 threadgroup char * shmem [[threadgroup(0)]],
7514 uint3 tgpig[[threadgroup_position_in_grid]],
7515 ushort tiisg[[thread_index_in_simdgroup]],
7516 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
7517
7518 kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
7519}
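
// [editor's note] iq2_xs widens the grid to 512 entries: each uint16 in q2 packs a
// 9-bit grid index (q2[l] & 511) and a 7-bit sign-table index (q2[l] >> 9), and each
// group carries two 4-bit scales in sc[0] (low nibble for the first 16 weights, high
// nibble for the rest), applied as d = db*(0.5f + ls) with the same final *0.25f.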
7520
7521template<int nr0, typename args_t>
7522void kernel_mul_mv_iq3_xxs_f32_impl(
7523 args_t args,
7524 device const char * src0,
7525 device const char * src1,
7526 device char * dst,
7527 threadgroup char * shmem,
7528 uint3 tgpig,
7529 ushort tiisg,
7530 ushort sgitg) {
7531 const short NSG = FC_mul_mv_nsg;
7532
7533 const int nb = args.ne00/QK_K;
7534
7535 const int r0 = tgpig.x;
7536 const int r1 = tgpig.y;
7537 const int im = tgpig.z;
7538
7539 const int first_row = (r0 * NSG + sgitg) * nr0;
7540
7541 const uint i12 = im%args.ne12;
7542 const uint i13 = im/args.ne12;
7543
7544 const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7545 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7546
7547 device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
7548 device const float * y = (device const float *) (src1 + offset1);
7549
7550 float yl[32];
7551 float sumf[nr0]={0.f};
7552
7553 const int nb32 = nb * (QK_K / 32);
7554
7555 threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem);
7556 threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256);
7557 {
7558 int nval = 4;
7559 int pos = (32*sgitg + tiisg)*nval;
7560 for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i];
7561 nval = 2;
7562 pos = (32*sgitg + tiisg)*nval;
7563 for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
7564 threadgroup_barrier(mem_flags::mem_threadgroup);
7565 }
7566
7567 const int ix = tiisg;
7568
7569 device const float * y4 = y + 32 * ix;
7570
7571 for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
7572 for (short i = 0; i < 32; ++i) {
7573 yl[i] = y4[i];
7574 }
7575
7576 const int ibl = ib32 / (QK_K / 32);
7577 const int ib = ib32 % (QK_K / 32);
7578
7579 device const block_iq3_xxs * xr = x + ibl;
7580 device const uint8_t * q3 = xr->qs + 8 * ib;
7581 device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
7582 device const half * dh = &xr->d;
7583
7584 for (short row = 0; row < nr0; row++) {
7585 const float db = dh[0];
7586 const uint32_t aux32 = gas[0] | (gas[1] << 16);
7587 const float d = db * (0.5f + (aux32 >> 28));
7588
7589 float2 sum = {0};
7590 for (short l = 0; l < 4; ++l) {
7591 const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
7592 const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
7593 const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
7594 for (short j = 0; j < 4; ++j) {
7595 sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
7596 sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
7597 }
7598 }
7599 sumf[row] += d * (sum[0] + sum[1]);
7600
7601 dh += args.nb01/2;
7602 q3 += args.nb01;
7603 gas += args.nb01/2;
7604 }
7605
7606 y4 += 32 * 32;
7607 }
7608
7609 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
7610
7611 for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
7612 float sum_all = simd_sum(sumf[row]);
7613 if (tiisg == 0) {
7614 dst_f32[first_row + row] = sum_all * 0.5f;
7615 }
7616 }
7617}
7618
7619[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
7620kernel void kernel_mul_mv_iq3_xxs_f32(
7621 constant ggml_metal_kargs_mul_mv & args,
7622 device const char * src0,
7623 device const char * src1,
7624 device char * dst,
7625 threadgroup char * shmem [[threadgroup(0)]],
7626 uint3 tgpig[[threadgroup_position_in_grid]],
7627 ushort tiisg[[thread_index_in_simdgroup]],
7628 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
7629
7630 kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
7631}
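
// [editor's note] iq3_xxs mirrors the iq2_xxs scheme with a 256-entry uint32 grid
// (4 weights per entry, hence the grid1/grid2 pair per 8 weights) and the same packed
// 7-bit sign indices in gas; the format's coarser 1/2 step shows up as the final
// * 0.5f.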
7632
7633template<int nr0, typename args_t>
7634void kernel_mul_mv_iq3_s_f32_impl(
7635 args_t args,
7636 device const char * src0,
7637 device const char * src1,
7638 device char * dst,
7639 threadgroup char * shmem,
7640 uint3 tgpig,
7641 ushort tiisg,
7642 ushort sgitg) {
7643 const short NSG = FC_mul_mv_nsg;
7644
7645 const int nb = args.ne00/QK_K;
7646
7647 const int r0 = tgpig.x;
7648 const int r1 = tgpig.y;
7649 const int im = tgpig.z;
7650
7651 const int first_row = (r0 * NSG + sgitg) * nr0;
7652
7653 const uint i12 = im%args.ne12;
7654 const uint i13 = im/args.ne12;
7655
7656 const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7657 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7658
7659 device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
7660 device const float * y = (device const float *) (src1 + offset1);
7661
7662 float yl[32];
7663 float sumf[nr0]={0.f};
7664
7665 const int nb32 = nb * (QK_K / 32);
7666
7667 threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem;
7668 {
7669 int nval = 8;
7670 int pos = (32*sgitg + tiisg)*nval;
7671 for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i];
7672 threadgroup_barrier(mem_flags::mem_threadgroup);
7673 }
7674
7675 const int ix = tiisg;
7676
7677 device const float * y4 = y + 32 * ix;
7678
7679 for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
7680 for (short i = 0; i < 32; ++i) {
7681 yl[i] = y4[i];
7682 }
7683
7684 const int ibl = ib32 / (QK_K / 32);
7685 const int ib = ib32 % (QK_K / 32);
7686
7687 device const block_iq3_s * xr = x + ibl;
7688 device const uint8_t * qs = xr->qs + 8 * ib;
7689 device const uint8_t * qh = xr->qh + ib;
7690 device const uint8_t * sc = xr->scales + (ib/2);
7691 device const uint8_t * signs = xr->signs + 4 * ib;
7692 device const half * dh = &xr->d;
7693
7694 for (short row = 0; row < nr0; row++) {
7695 const float db = dh[0];
7696 const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
7697
7698 float2 sum = {0};
7699 for (short l = 0; l < 4; ++l) {
7700 const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
7701 const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
7702 const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
7703 const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
7704 for (short j = 0; j < 4; ++j) {
7705 sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
7706 sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
7707 }
7708 }
7709 sumf[row] += d * (sum[0] + sum[1]);
7710
7711 dh += args.nb01/2;
7712 qs += args.nb01;
7713 qh += args.nb01;
7714 sc += args.nb01;
7715 signs += args.nb01;
7716 }
7717
7718 y4 += 32 * 32;
7719 }
7720
7721 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
7722
7723 for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
7724 float sum_all = simd_sum(sumf[row]);
7725 if (tiisg == 0) {
7726 dst_f32[first_row + row] = sum_all;
7727 }
7728 }
7729}
7730
7731[[host_name("kernel_mul_mv_iq3_s_f32")]]
7732kernel void kernel_mul_mv_iq3_s_f32(
7733 constant ggml_metal_kargs_mul_mv & args,
7734 device const char * src0,
7735 device const char * src1,
7736 device char * dst,
7737 threadgroup char * shmem [[threadgroup(0)]],
7738 uint3 tgpig[[threadgroup_position_in_grid]],
7739 ushort tiisg[[thread_index_in_simdgroup]],
7740 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
7741
7742 kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
7743}
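
// [editor's note] In iq3_s the 512-entry grid is addressed as two 256-entry halves:
// qs supplies the low 8 bits and the kmask_iq2xs bit taken from qh selects the upper
// half (svalues + 256), supplying the 9th index bit. Signs here are stored explicitly
// per group rather than via the shared sign table, and scales map through
// d = db*(1 + 2*s).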
7744
7745template<int nr0, typename args_t>
7746void kernel_mul_mv_iq2_s_f32_impl(
7747 args_t args,
7748 device const char * src0,
7749 device const char * src1,
7750 device char * dst,
7751 threadgroup char * shmem,
7752 uint3 tgpig,
7753 ushort tiisg,
7754 ushort sgitg) {
7755 const short NSG = FC_mul_mv_nsg;
7756
7757 const int nb = args.ne00/QK_K;
7758
7759 const int r0 = tgpig.x;
7760 const int r1 = tgpig.y;
7761 const int im = tgpig.z;
7762
7763 const int first_row = (r0 * NSG + sgitg) * nr0;
7764
7765 const uint i12 = im%args.ne12;
7766 const uint i13 = im/args.ne12;
7767
7768 const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7769 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7770
7771 device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
7772 device const float * y = (device const float *) (src1 + offset1);
7773
7774 float yl[32];
7775 float sumf[nr0]={0.f};
7776
7777 const int nb32 = nb * (QK_K / 32);
7778
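    // [editor's note] the threadgroup staging below is left disabled in the source;
    // the iq2s grid (1024 entries, per the 0x300 index masks) is instead read straight
    // from constant memory in the inner loop.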
7779 //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem;
7780 //{
7781 // int nval = 32;
7782 // int pos = (32*sgitg + tiisg)*nval;
7783 // for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i];
7784 // threadgroup_barrier(mem_flags::mem_threadgroup);
7785 //}
7786
7787 const short ix = tiisg;
7788
7789 device const float * y4 = y + 32 * ix;
7790
7791 for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
7792 for (short i = 0; i < 32; ++i) {
7793 yl[i] = y4[i];
7794 }
7795
7796 const int ibl = ib32 / (QK_K / 32);
7797 const int ib = ib32 % (QK_K / 32);
7798
7799 device const block_iq2_s * xr = x + ibl;
7800 device const uint8_t * qs = xr->qs + 4 * ib;
7801 device const uint8_t * qh = xr->qh + ib;
7802 device const uint8_t * sc = xr->scales + ib;
7803 device const uint8_t * signs = qs + QK_K/8;
7804 device const half * dh = &xr->d;
7805
7806 for (short row = 0; row < nr0; row++) {
7807 const float db = dh[0];
7808 const float d1 = db * (0.5f + (sc[0] & 0xf));
7809 const float d2 = db * (0.5f + (sc[0] >> 4));
7810
7811 float2 sum = {0};
7812 for (short l = 0; l < 2; ++l) {
7813 //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
7814 //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
7815 constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
7816 constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
7817 for (short j = 0; j < 8; ++j) {
7818 sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
7819 sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
7820 }
7821 }
7822 sumf[row] += d1 * sum[0] + d2 * sum[1];
7823
7824 dh += args.nb01/2;
7825 qs += args.nb01;
7826 qh += args.nb01;
7827 sc += args.nb01;
7828 signs += args.nb01;
7829 }
7830
7831 y4 += 32 * 32;
7832 }
7833
7834 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
7835
7836 for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
7837 float sum_all = simd_sum(sumf[row]);
7838 if (tiisg == 0) {
7839 dst_f32[first_row + row] = sum_all * 0.25f;
7840 }
7841 }
7842}
7843
7844[[host_name("kernel_mul_mv_iq2_s_f32")]]
7845kernel void kernel_mul_mv_iq2_s_f32(
7846 constant ggml_metal_kargs_mul_mv & args,
7847 device const char * src0,
7848 device const char * src1,
7849 device char * dst,
7850 threadgroup char * shmem [[threadgroup(0)]],
7851 uint3 tgpig[[threadgroup_position_in_grid]],
7852 ushort tiisg[[thread_index_in_simdgroup]],
7853 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
7854
7855 kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
7856}
7857
7858template<int nr0, typename args_t>
7859void kernel_mul_mv_iq1_s_f32_impl(
7860 args_t args,
7861 device const char * src0,
7862 device const char * src1,
7863 device char * dst,
7864 threadgroup char * shmem,
7865 uint3 tgpig,
7866 ushort tiisg,
7867 ushort sgitg) {
7868 const short NSG = FC_mul_mv_nsg;
7869
7870 const int nb = args.ne00/QK_K;
7871
7872 const int r0 = tgpig.x;
7873 const int r1 = tgpig.y;
7874 const int im = tgpig.z;
7875
7876 const int first_row = (r0 * NSG + sgitg) * nr0;
7877
7878 const uint i12 = im%args.ne12;
7879 const uint i13 = im/args.ne12;
7880
7881 const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7882 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7883
7884 device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
7885 device const float * y = (device const float *) (src1 + offset1);
7886
7887 float yl[32];
7888 float sumf[nr0]={0.f};
7889
7890 const int nb32 = nb * (QK_K / 32);
7891
7892 const short ix = tiisg;
7893
7894 device const float * y4 = y + 32 * ix;
7895
7896 for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
7897 float sumy = 0;
7898 for (short i = 0; i < 32; ++i) {
7899 yl[i] = y4[i];
7900 sumy += yl[i];
7901 }
7902
7903 const int ibl = ib32 / (QK_K / 32);
7904 const int ib = ib32 % (QK_K / 32);
7905
7906 device const block_iq1_s * xr = x + ibl;
7907 device const uint8_t * qs = xr->qs + 4 * ib;
7908 device const uint16_t * qh = xr->qh + ib;
7909 device const half * dh = &xr->d;
7910
7911 for (short row = 0; row < nr0; row++) {
7912 constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
7913 constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
7914 constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
7915 constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
7916
7917 float sum = 0;
7918 for (short j = 0; j < 4; ++j) {
7919 sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
7920 + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
7921 + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
7922 + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
7923 }
7924 sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
7925
7926 dh += args.nb01/2;
7927 qs += args.nb01;
7928 qh += args.nb01/2;
7929 }
7930
7931 y4 += 32 * 32;
7932 }
7933
7934 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
7935
7936 for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
7937 float sum_all = simd_sum(sumf[row]);
7938 if (tiisg == 0) {
7939 dst_f32[first_row + row] = sum_all;
7940 }
7941 }
7942}
7943
7944[[host_name("kernel_mul_mv_iq1_s_f32")]]
7945kernel void kernel_mul_mv_iq1_s_f32(
7946 constant ggml_metal_kargs_mul_mv & args,
7947 device const char * src0,
7948 device const char * src1,
7949 device char * dst,
7950 uint3 tgpig[[threadgroup_position_in_grid]],
7951 ushort tiisg[[thread_index_in_simdgroup]],
7952 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
7953
7954 kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
7955}
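
// [editor's note] iq1_s carries no per-weight scale bits: qs plus three qh bits select
// a grid entry, the top qh bits hold a 3-bit group scale applied as (2*s + 1), and bit
// 0x8000 picks the sign of a constant group shift. The sumy*(-1 +/- IQ1S_DELTA) term
// appears to both remove a +1 bias in the stored grid values and apply that shift in a
// single multiply.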
7956
7957template<int nr0, typename args_t>
7958void kernel_mul_mv_iq1_m_f32_impl(
7959 args_t args,
7960 device const char * src0,
7961 device const char * src1,
7962 device char * dst,
7963 threadgroup char * shmem,
7964 uint3 tgpig,
7965 ushort tiisg,
7966 ushort sgitg) {
7967 const short NSG = FC_mul_mv_nsg;
7968
7969 const int nb = args.ne00/QK_K;
7970
7971 const int r0 = tgpig.x;
7972 const int r1 = tgpig.y;
7973 const int im = tgpig.z;
7974
7975 const int first_row = (r0 * NSG + sgitg) * nr0;
7976
7977 const uint i12 = im%args.ne12;
7978 const uint i13 = im/args.ne12;
7979
7980 const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
7981 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
7982
7983 device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
7984 device const float * y = (device const float *) (src1 + offset1);
7985
7986 float yl[32];
7987 float sumf[nr0]={0.f};
7988
7989 const int nb32 = nb * (QK_K / 32);
7990
7991 const short ix = tiisg;
7992
7993 device const float * y4 = y + 32 * ix;
7994
7995 iq1m_scale_t scale;
7996
7997 for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
7998 float4 sumy = {0.f};
7999 for (short i = 0; i < 8; ++i) {
8000 yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
8001 yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
8002 yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
8003 yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
8004 }
8005
8006 const int ibl = ib32 / (QK_K / 32);
8007 const int ib = ib32 % (QK_K / 32);
8008
8009 device const block_iq1_m * xr = x + ibl;
8010 device const uint8_t * qs = xr->qs + 4 * ib;
8011 device const uint8_t * qh = xr->qh + 2 * ib;
8012 device const uint16_t * sc = (device const uint16_t *)xr->scales;
8013
8014 for (short row = 0; row < nr0; row++) {
8015 scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
8016
8017 constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
8018 constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
8019 constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
8020 constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
8021
8022 float2 sum = {0.f};
8023 for (short j = 0; j < 4; ++j) {
8024 sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
8025 + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
8026 sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
8027 + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
8028 }
8029 const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
8030 const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
8031
8032 sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
8033 (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
8034
8035 sc += args.nb01/2;
8036 qs += args.nb01;
8037 qh += args.nb01;
8038 }
8039
8040 y4 += 32 * 32;
8041 }
8042
8043 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
8044
8045 for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
8046 float sum_all = simd_sum(sumf[row]);
8047 if (tiisg == 0) {
8048 dst_f32[first_row + row] = sum_all;
8049 }
8050 }
8051}
8052
8053[[host_name("kernel_mul_mv_iq1_m_f32")]]
8054kernel void kernel_mul_mv_iq1_m_f32(
8055 constant ggml_metal_kargs_mul_mv & args,
8056 device const char * src0,
8057 device const char * src1,
8058 device char * dst,
8059 uint3 tgpig[[threadgroup_position_in_grid]],
8060 ushort tiisg[[thread_index_in_simdgroup]],
8061 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
8062
8063 kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
8064}
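
// [editor's note] iq1_m reassembles its block scale from four 4-bit fragments spread
// across the scale words: (sc[0]>>12) | ((sc[1]>>8)&0x00f0) | ((sc[2]>>4)&0x0f00) |
// (sc[3]&0xf000) rebuilds an fp16 bit pattern in scale.u16, read back through
// scale.f16. Each half-group then gets its own 3-bit multiplier (2*s + 1) and its own
// delta sign from qh (bits 0x08 and 0x80).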
8065
8066template<int NR0, typename args_t>
8067void kernel_mul_mv_iq4_nl_f32_impl(
8068 args_t args,
8069 device const char * src0,
8070 device const char * src1,
8071 device char * dst,
8072 threadgroup char * shmem,
8073 uint3 tgpig,
8074 ushort tiisg,
8075 ushort sgitg) {
8076 const short NSG = FC_mul_mv_nsg;
8077
8078 threadgroup float * shmem_f32 = (threadgroup float *) shmem;
8079
8080 const int r0 = tgpig.x;
8081 const int r1 = tgpig.y;
8082 const int im = tgpig.z;
8083
8084 const int first_row = (r0 * NSG + sgitg) * NR0;
8085
8086 const uint i12 = im%args.ne12;
8087 const uint i13 = im/args.ne12;
8088
8089 const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8090 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
8091
8092 device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
8093 device const float * y = (device const float *) (src1 + offset1);
8094
8095 const int nb = args.ne00/QK4_NL;
8096 const int ns01 = args.nb01/args.nb00;
8097
8098 const short ix = tiisg/2; // 0...15
8099 const short it = tiisg%2; // 0 or 1
8100
8101 shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
8102 threadgroup_barrier(mem_flags::mem_threadgroup);
8103
8104 float4 yl[4];
8105 float sumf[NR0]={0.f};
8106
8107 device const float * yb = y + ix*QK4_NL + it*8;
8108
8109 uint32_t aux32[2];
8110 thread const uint8_t * q8 = (thread const uint8_t *)aux32;
8111
8112 float4 qf1, qf2;
8113
8114 // [TAG_MUL_MV_WEIRD]
8115 for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
8116 device const float4 * y4 = (device const float4 *)yb;
8117 yl[0] = y4[0];
8118 yl[1] = y4[4];
8119 yl[2] = y4[1];
8120 yl[3] = y4[5];
8121
8122 for (short row = 0; row < NR0; row++) {
8123 device const block_iq4_nl & xb = x[row*ns01 + ib];
8124 device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
8125
8126 float4 acc1 = {0.f}, acc2 = {0.f};
8127
8128 aux32[0] = q4[0] | (q4[1] << 16);
8129 aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
8130 aux32[0] &= 0x0f0f0f0f;
8131 qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
8132 qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
8133 acc1 += yl[0] * qf1;
8134 acc2 += yl[1] * qf2;
8135
8136 aux32[0] = q4[2] | (q4[3] << 16);
8137 aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
8138 aux32[0] &= 0x0f0f0f0f;
8139 qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
8140 qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
8141 acc1 += yl[2] * qf1;
8142 acc2 += yl[3] * qf2;
8143
8144 acc1 += acc2;
8145
8146 sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
8147 }
8148
8149 yb += 16 * QK4_NL;
8150 }
8151
8152 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
8153
8154 for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
8155 float sum_all = simd_sum(sumf[row]);
8156 if (tiisg == 0) {
8157 dst_f32[first_row + row] = sum_all;
8158 }
8159 }
8160}
8161
8162[[host_name("kernel_mul_mv_iq4_nl_f32")]]
8163kernel void kernel_mul_mv_iq4_nl_f32(
8164 constant ggml_metal_kargs_mul_mv & args,
8165 device const char * src0,
8166 device const char * src1,
8167 device char * dst,
8168 threadgroup char * shmem [[threadgroup(0)]],
8169 uint3 tgpig[[threadgroup_position_in_grid]],
8170 ushort tiisg[[thread_index_in_simdgroup]],
8171 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
8172
8173 kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
8174}
8175
8176template<int NR0, typename args_t>
8177void kernel_mul_mv_iq4_xs_f32_impl(
8178 args_t args,
8179 device const char * src0,
8180 device const char * src1,
8181 device char * dst,
8182 threadgroup char * shmem,
8183 uint3 tgpig,
8184 ushort tiisg,
8185 ushort sgitg) {
8186 const short NSG = FC_mul_mv_nsg;
8187
8188 threadgroup float * shmem_f32 = (threadgroup float *) shmem;
8189
8190 const int r0 = tgpig.x;
8191 const int r1 = tgpig.y;
8192 const int im = tgpig.z;
8193 const int first_row = (r0 * NSG + sgitg) * NR0;
8194
8195 const uint i12 = im%args.ne12;
8196 const uint i13 = im/args.ne12;
8197
8198 const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8199 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
8200
8201 device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
8202 device const float * y = (device const float *) (src1 + offset1);
8203
8204 const int nb = args.ne00/QK_K;
8205 const int ns01 = args.nb01/args.nb00;
8206
8207 const short ix = tiisg/16; // 0 or 1
8208 const short it = tiisg%16; // 0...15
8209 const short ib = it/2;
8210 const short il = it%2;
8211
8212 shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
8213 threadgroup_barrier(mem_flags::mem_threadgroup);
8214
8215 float4 yl[4];
8216 float sumf[NR0]={0.f};
8217
8218 device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
8219
8220 uint32_t aux32[2];
8221 thread const uint8_t * q8 = (thread const uint8_t *)aux32;
8222
8223 float4 qf1, qf2;
8224
8225    // [TAG_MUL_MV_WEIRD] - the redundant `ibl < ns01` check is faster; see the note in kernel_mul_mv_mxfp4_f32_impl
8226 for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) {
8227 device const float4 * y4 = (device const float4 *)yb;
8228 yl[0] = y4[0];
8229 yl[1] = y4[4];
8230 yl[2] = y4[1];
8231 yl[3] = y4[5];
8232
8233 for (short row = 0; row < NR0; ++row) {
8234 device const block_iq4_xs & xb = x[row*ns01 + ibl];
8235 device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
8236
8237 float4 acc1 = {0.f}, acc2 = {0.f};
8238
8239 aux32[0] = (q4[0] ) & 0x0f0f0f0f;
8240 aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
8241 qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
8242 qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
8243 acc1 += yl[0] * qf1;
8244 acc2 += yl[1] * qf2;
8245
8246 aux32[0] = (q4[1] ) & 0x0f0f0f0f;
8247 aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
8248 qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
8249 qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
8250 acc1 += yl[2] * qf1;
8251 acc2 += yl[3] * qf2;
8252
8253 acc1 += acc2;
8254
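                // 6-bit sub-block scale: 4 low bits from scales_l, 2 high bits from
                // scales_h, biased by -32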
8255 const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
8256 sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
8257 }
8258
8259 yb += 2 * QK_K;
8260 }
8261
8262 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
8263
8264 for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
8265 float sum_all = simd_sum(sumf[row]);
8266 if (tiisg == 0) {
8267 dst_f32[first_row + row] = sum_all;
8268 }
8269 }
8270}
8271
8272[[host_name("kernel_mul_mv_iq4_xs_f32")]]
8273kernel void kernel_mul_mv_iq4_xs_f32(
8274 constant ggml_metal_kargs_mul_mv & args,
8275 device const char * src0,
8276 device const char * src1,
8277 device char * dst,
8278 threadgroup char * shmem [[threadgroup(0)]],
8279 uint3 tgpig[[threadgroup_position_in_grid]],
8280 ushort tiisg[[thread_index_in_simdgroup]],
8281 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
8282
8283 kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
8284}
8285
8286template<int NR0, typename args_t>
8287void kernel_mul_mv_mxfp4_f32_impl(
8288 args_t args,
8289 device const char * src0,
8290 device const char * src1,
8291 device char * dst,
8292 threadgroup char * shmem,
8293 uint3 tgpig,
8294 ushort tiisg,
8295 ushort sgitg) {
8296 const short NSG = FC_mul_mv_nsg;
8297
8298 threadgroup float * shmem_f32 = (threadgroup float *) shmem;
8299
8300 const int r0 = tgpig.x;
8301 const int r1 = tgpig.y;
8302 const int im = tgpig.z;
8303
8304 const int first_row = (r0 * NSG + sgitg) * NR0;
8305
8306 const uint i12 = im%args.ne12;
8307 const uint i13 = im/args.ne12;
8308
8309 const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8310 const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
8311
8312 device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
8313 device const float * y = (device const float *) (src1 + offset1);
8314
8315 const int nb = args.ne00/QK_MXFP4;
8316    const int ns01 = args.nb01/args.nb00; // src0 row stride in blocks - can be larger than nb for permuted src0 tensors
8317
8318 const short ix = tiisg/2; // 0...15
8319 const short it = tiisg%2; // 0 or 1
8320
8321 shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
8322 threadgroup_barrier(mem_flags::mem_threadgroup);
8323
8324 float4 yl[4];
8325 float sumf[NR0]={0.f};
8326
8327 device const float * yb = y + ix*QK_MXFP4 + it*8;
8328
8329 // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster
8330 // no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD]
8331 for (int ib = ix; ib < nb && ib < ns01; ib += 16) {
8332 device const float4 * y4 = (device const float4 *) yb;
8333
8334 yl[0] = y4[0];
8335 yl[1] = y4[4];
8336 yl[2] = y4[1];
8337 yl[3] = y4[5];
8338
8339 FOR_UNROLL (short row = 0; row < NR0; row++) {
8340 device const block_mxfp4 & xb = x[row*ns01 + ib];
8341 device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);
8342
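            // each qs byte packs two FP4 codes: the low nibble belongs to the first
            // half of the block, the high nibble to the second half; xb.e is a shared
            // E8M0 (power-of-two) scale applied to all 32 weights below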
8343 float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
8344 float4 acc2 = yl[1]*float4(shmem_f32[q2[0] >> 4 ], shmem_f32[q2[1] >> 4 ], shmem_f32[q2[2] >> 4 ], shmem_f32[q2[3] >> 4 ]);
8345 float4 acc3 = yl[2]*float4(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]);
8346 float4 acc4 = yl[3]*float4(shmem_f32[q2[4] >> 4 ], shmem_f32[q2[5] >> 4 ], shmem_f32[q2[6] >> 4 ], shmem_f32[q2[7] >> 4 ]);
8347
8348 acc1 = (acc1 + acc3) + (acc2 + acc4);
8349
8350 sumf[row] += e8m0_to_fp32(xb.e) * ((acc1[0] + acc1[1]) + (acc1[2] + acc1[3]));
8351 }
8352
8353 yb += 16 * QK_MXFP4;
8354 }
8355
8356 device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
8357
8358 for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) {
8359 float sum_all = simd_sum(sumf[row]);
8360 if (tiisg == 0) {
8361 dst_f32[first_row + row] = sum_all;
8362 }
8363 }
8364}
8365
8366[[host_name("kernel_mul_mv_mxfp4_f32")]]
8367kernel void kernel_mul_mv_mxfp4_f32(
8368 constant ggml_metal_kargs_mul_mv & args,
8369 device const char * src0,
8370 device const char * src1,
8371 device char * dst,
8372 threadgroup char * shmem [[threadgroup(0)]],
8373 uint3 tgpig[[threadgroup_position_in_grid]],
8374 ushort tiisg[[thread_index_in_simdgroup]],
8375 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
8376
8377 kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
8378}
8379
8380template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
8381kernel void kernel_get_rows_q(
8382 constant ggml_metal_kargs_get_rows & args,
8383 device const void * src0,
8384 device const void * src1,
8385 device void * dst,
8386 uint3 tgpig[[threadgroup_position_in_grid]],
8387 ushort tiitg[[thread_index_in_threadgroup]],
8388 ushort3 ntg [[threads_per_threadgroup]]) {
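    // tgpig.x packs two indices: i10 is the row index into src1, iw0 selects which
    // ntg.x-wide chunk of the ne00t output elements this threadgroup handles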
8389 const int32_t iw0 = tgpig.x/args.ne10;
8390 const int32_t i10 = tgpig.x%args.ne10;
8391 const int32_t i11 = tgpig.y;
8392 const int32_t i12 = tgpig.z;
8393
8394 const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
8395
8396 const int32_t i02 = i11;
8397 const int32_t i03 = i12;
8398
8399 auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
8400 auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
8401
8402    const int ind = iw0*ntg.x + tiitg;
8403
8404    if (ind < args.ne00t) { // each thread dequantizes at most one element
8405        float4x4 temp;
8406        dequantize_func(psrc + ind/nl, ind%nl, temp);
8407        pdst[ind] = temp;
8408    }
8409}
8410
8411template<typename T0, typename T>
8412kernel void kernel_get_rows_f(
8413 constant ggml_metal_kargs_get_rows & args,
8414 device const void * src0,
8415 device const void * src1,
8416 device void * dst,
8417 uint3 tgpig[[threadgroup_position_in_grid]],
8418 ushort tiitg[[thread_index_in_threadgroup]],
8419 ushort3 ntg [[threads_per_threadgroup]]) {
8420 const int32_t iw0 = tgpig.x/args.ne10;
8421 const int32_t i10 = tgpig.x%args.ne10;
8422 const int32_t i11 = tgpig.y;
8423 const int32_t i12 = tgpig.z;
8424
8425 const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
8426
8427 const int32_t i02 = i11;
8428 const int32_t i03 = i12;
8429
8430 auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
8431 auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
8432
8433    const int ind = iw0*ntg.x + tiitg;
8434
8435    if (ind < args.ne00t) { // each thread copies at most one element
8436        pdst[ind] = psrc[ind];
8437    }
8438}
8439
8440template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
8441kernel void kernel_set_rows_q32(
8442 constant ggml_metal_kargs_set_rows & args,
8443 device const void * src0,
8444 device const void * src1,
8445 device float * dst,
8446 uint3 tgpig[[threadgroup_position_in_grid]],
8447 uint tiitg[[thread_index_in_threadgroup]],
8448 uint3 tptg [[threads_per_threadgroup]]) {
8449 const int32_t i03 = tgpig.z;
8450 const int32_t i02 = tgpig.y;
8451
8452 const int32_t i12 = i03%args.ne12;
8453 const int32_t i11 = i02%args.ne11;
8454
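    // each threadgroup covers tptg.y consecutive rows; within a row, tptg.x threads
    // stride over the row's quantized blocks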
8455 const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
8456 if (i01 >= args.ne01) {
8457 return;
8458 }
8459
8460 const int32_t i10 = i01;
8461 const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
8462
8463 device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
8464 const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
8465
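    // all quant types routed here use 32-element blocks (hence the _q32 name and the
    // fixed stride of 32 floats per quantized block)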
8466 for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
8467 quantize_func(src_row + 32*ind, dst_row[ind]);
8468 }
8469}
8470
8471template<typename T, typename TI>
8472kernel void kernel_set_rows_f(
8473 constant ggml_metal_kargs_set_rows & args,
8474 device const void * src0,
8475 device const void * src1,
8476 device float * dst,
8477 uint3 tgpig[[threadgroup_position_in_grid]],
8478 uint tiitg[[thread_index_in_threadgroup]],
8479 uint3 tptg [[threads_per_threadgroup]]) {
8480 const int32_t i03 = tgpig.z;
8481 const int32_t i02 = tgpig.y;
8482
8483 const int32_t i12 = i03%args.ne12;
8484 const int32_t i11 = i02%args.ne11;
8485
8486 const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
8487 if (i01 >= args.ne01) {
8488 return;
8489 }
8490
8491 const int32_t i10 = i01;
8492 const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
8493
8494 device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
8495 const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
8496
8497 for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
8498 dst_row[ind] = (T) src_row[ind];
8499 }
8500}
8501
8502kernel void kernel_diag_f32(
8503 constant ggml_metal_kargs_diag & args,
8504 device const char * src0,
8505 device char * dst,
8506 uint3 tgpig[[threadgroup_position_in_grid]],
8507 ushort tiitg[[thread_index_in_threadgroup]]) {
8508 constexpr short NW = N_SIMDWIDTH;
8509
8510 const int32_t i3 = tgpig.z;
8511 const int32_t i2 = tgpig.y;
8512 const int32_t i1 = tgpig.x;
8513
8514 device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03);
8515 device float * dst_ptr = (device float *)(dst + i1*args.nb01 + i2*args.nb2 + i3*args.nb3);
8516
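    // builds a diagonal matrix: output row i1 gets src0[i0] at i0 == i1, zeros elsewhere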
8517 for (int i0 = tiitg; i0 < args.ne0; i0 += NW) {
8518 dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f;
8519 }
8520}
8521
8522constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]];
8523constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]];
8524
8525// each block_q contains 16*nl weights
8526template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
8527kernel void kernel_mul_mm(
8528 constant ggml_metal_kargs_mul_mm & args,
8529 device const char * src0,
8530 device const char * src1,
8531 device char * dst,
8532 threadgroup char * shmem [[threadgroup(0)]],
8533 uint3 tgpig[[threadgroup_position_in_grid]],
8534 ushort tiitg[[thread_index_in_threadgroup]],
8535 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
8536
8537 threadgroup S0 * sa = (threadgroup S0 *)(shmem);
8538 threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
8539
8540 threadgroup float * sc = (threadgroup float *)(shmem);
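    // sa holds a 64x32 tile of 2-byte elements (4096 B), hence the fixed sb offset;
    // sc reuses the same threadgroup memory to stage the output tile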
8541
8542 constexpr int NR0 = 64;
8543 constexpr int NR1 = 32;
8544
8545 constexpr int NK = 32;
8546 constexpr int NL0 = NK/16;
8547 constexpr int NL1 = NK/8;
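    // 64x32 output tile per threadgroup, accumulated over the K dimension in steps of
    // NK = 32; NL0/NL1 are the number of threads that cooperatively load one row of
    // src0/src1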
8548
8549 const int im = tgpig.z;
8550 const int r0 = tgpig.y*NR0;
8551 const int r1 = tgpig.x*NR1;
8552
8553    // clamp the tile size in case this block is smaller than 64x32 (at the matrix edge)
8554 const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
8555 const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1;
8556
8557 // a thread shouldn't load data outside of the matrix
8558 const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
8559 const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
8560
8561 const short il0 = (tiitg % NL0);
8562
8563 short il = il0;
8564
8565 const int i12 = im%args.ne12;
8566 const int i13 = im/args.ne12;
8567
8568 const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
8569 const short offset1 = il0/nl;
8570
8571 device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
8572
8573 const short iy = 8*(tiitg % NL1);
8574
8575 device const T1 * y = (device const T1 *)(src1
8576 + args.nb13*i13
8577 + args.nb12*i12
8578 + args.nb11*(r1 + lr1)
8579 + args.nb10*iy);
8580
8581#ifndef GGML_METAL_HAS_TENSOR
8582 S0_8x8 ma[4];
8583 S1_8x8 mb[2];
8584
8585 simdgroup_float8x8 mc[8];
8586
8587 for (short i = 0; i < 8; i++){
8588 mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
8589 }
8590#else
8591 auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
8592 auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
8593
8594 mpp::tensor_ops::matmul2d<
8595 mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
8596 execution_simdgroups<4>> mm;
8597
8598 auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
8599#endif
8600
8601 for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
8602#ifndef GGML_METAL_HAS_TENSOR
8603 // load data and store to threadgroup memory
8604 if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
8605 threadgroup_barrier(mem_flags::mem_threadgroup);
8606
8607 // no need for dequantization
8608 for (short i = 0; i < 16; i++) {
8609 const short sx = 2*il0 + i/8;
8610 const short sy = (tiitg/NL0)/8;
8611
8612 //const short lx = i%8;
8613 //const short ly = (tiitg/NL0)%8;
8614 const short lx = (tiitg/NL0)%8;
8615 const short ly = i%8;
8616
8617 const short ib = 8*sx + sy;
8618
8619 *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
8620 }
8621 } else {
8622 S0_4x4 temp_a;
8623 dequantize_func(x, il, temp_a);
8624
8625 threadgroup_barrier(mem_flags::mem_threadgroup);
8626
8627 FOR_UNROLL (short i = 0; i < 16; i++) {
8628 const short sx = 2*il0 + i/8;
8629 const short sy = (tiitg/NL0)/8;
8630
8631 //const short lx = i%8;
8632 //const short ly = (tiitg/NL0)%8;
8633 const short lx = (tiitg/NL0)%8;
8634 const short ly = i%8;
8635
8636 const short ib = 8*sx + sy;
8637
8638                // NOTE: the array-subscript form below is massively slower than the pointer form - reason unknown
8639 //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
8640
8641 *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
8642 }
8643 }
8644
8645 if (FC_mul_mm_bc_inp) {
8646 for (short i = 0; i < 8; ++i) {
8647 const short sx = (tiitg%NL1);
8648 const short sy = (tiitg/NL1)/8;
8649
8650 const short lx = i;
8651 const short ly = (tiitg/NL1)%8;
8652 //const short lx = (tiitg/NL1)%8;
8653 //const short ly = i;
8654
8655 const short ib = 4*sx + sy;
8656
8657 *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
8658 }
8659 } else {
8660 const short sx = (tiitg%NL1);
8661 const short sy = (tiitg/NL1)/8;
8662
8666 const short ly = (tiitg/NL1)%8;
8667
8668 const short ib = 4*sx + sy;
8669
8670 *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
8671 }
8672#else
8673 // load data and store to threadgroup memory
8674 if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
8675 threadgroup_barrier(mem_flags::mem_threadgroup);
8676
8677 // no need for dequantization
8678 for (short i = 0; i < 16; i++) {
8679 const short sx = 2*il0 + i/8;
8680 const short sy = (tiitg/NL0)/8;
8681
8682 const short lx = i%8;
8683 const short ly = (tiitg/NL0)%8;
8684 //const short lx = (tiitg/NL0)%8;
8685 //const short ly = i%8;
8686
8687 *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
8688 }
8689 } else {
8690 S0_4x4 temp_a;
8691 dequantize_func(x, il, temp_a);
8692
8693 threadgroup_barrier(mem_flags::mem_threadgroup);
8694
8695 FOR_UNROLL (short i = 0; i < 16; i++) {
8696 const short sx = 2*il0 + i/8;
8697 const short sy = (tiitg/NL0)/8;
8698
8699 const short lx = i%8;
8700 const short ly = (tiitg/NL0)%8;
8701 //const short lx = (tiitg/NL0)%8;
8702 //const short ly = i%8;
8703
8704 *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
8705 }
8706 }
8707
8708 if (FC_mul_mm_bc_inp) {
8709 for (short i = 0; i < 8; ++i) {
8710 const short sx = (tiitg%NL1);
8711 const short sy = (tiitg/NL1)/8;
8712
8713 const short lx = i;
8714 const short ly = (tiitg/NL1)%8;
8715 //const short lx = (tiitg/NL1)%8;
8716 //const short ly = i;
8717
8718 *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
8719 }
8720 } else {
8721 const short sx = (tiitg%NL1);
8722 const short sy = (tiitg/NL1)/8;
8723
8724 //const short lx = i;
8725 const short ly = (tiitg/NL1)%8;
8726 //const short lx = (tiitg/NL1)%8;
8727 //const short ly = i;
8728
8729 *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
8730 }
8731#endif
8732
8733 il = (il + 2 < nl) ? il + 2 : il % 2;
8734 x = (il < 2) ? x + (2 + nl - 1)/nl : x;
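        // step to the next pair of 16-weight sub-blocks; x only advances to the next
        // block_q once all nl sub-blocks of the current one have been consumed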
8735
8736 y += NK;
8737
8738 threadgroup_barrier(mem_flags::mem_threadgroup);
8739
8740#ifndef GGML_METAL_HAS_TENSOR
8741 // load matrices from threadgroup memory and conduct outer products
8742 threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
8743 threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
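        // the 4 simdgroups tile the 64x32 output 2x2: each accumulates a 32x16
        // sub-tile built from 4x2 simdgroup 8x8 fragments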
8744
8745 FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
8746 simdgroup_barrier(mem_flags::mem_none);
8747
8748 FOR_UNROLL (short i = 0; i < 4; i++) {
8749 simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
8750 }
8751
8752 simdgroup_barrier(mem_flags::mem_none);
8753
8754 FOR_UNROLL (short i = 0; i < 2; i++) {
8755 simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
8756 }
8757
8758 simdgroup_barrier(mem_flags::mem_none);
8759
8760 FOR_UNROLL (short i = 0; i < 8; i++){
8761 simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
8762 }
8763
8764 lsma += 8*64;
8765 lsmb += 4*64;
8766 }
8767#else
8768 auto sA = tA.slice(0, 0);
8769 auto sB = tB.slice(0, 0);
8770
8771 mm.run(sB, sA, cT);
8772#endif
8773 }
8774
8775 if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) {
8776 // if no bounds checks on the output are needed, we can directly write to device memory
8777#ifdef GGML_METAL_HAS_TENSOR
8778 device float * C = (device float *) dst +
8779        r0 +
8780 r1 * args.ne0 + im*args.ne1*args.ne0;
8781
8782 auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1));
8783 cT.store(tC);
8784#else
8785 device float * C = (device float *) dst +
8786        (r0 + 32*(sgitg & 1)) +
8787 (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
8788
8789 for (short i = 0; i < 8; i++) {
8790 simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false);
8791 }
8792#endif
8793 } else {
8794 // block is smaller than 64x32, we should avoid writing data outside of the matrix
8795 threadgroup_barrier(mem_flags::mem_threadgroup);
8796
8797 threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
8798
8799#ifdef GGML_METAL_HAS_TENSOR
8800 auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
8801 cT.store(tC);
8802#else
8803 for (short i = 0; i < 8; i++) {
8804 simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
8805 }
8806#endif
8807
8808 threadgroup_barrier(mem_flags::mem_threadgroup);
8809
8810 if (sgitg == 0) {
8811 for (int j = tiitg; j < nr1; j += NR1) {
8812 device float * D = (device float *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0;
8813 device float4 * D4 = (device float4 *) D;
8814
8815 threadgroup float * C = temp_str + (j*NR0);
8816 threadgroup float4 * C4 = (threadgroup float4 *) C;
8817
8818 int i = 0;
8819 for (; i < nr0/4; i++) {
8820 *(D4 + i) = *(C4 + i);
8821 }
8822
8823 i *= 4;
8824 for (; i < nr0; i++) {
8825 *(D + i) = *(C + i);
8826 }
8827 }
8828 }
8829 }
8830}
8831
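// compacts the token->expert routing for the indirect matmul: thread ide scans all
// ne21 tokens, gathers the packed ids of the rows routed to expert ide into hids,
// and writes the per-expert row count to htpe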
8832template<short ne20> // n_expert_used
8833kernel void kernel_mul_mm_id_map0(
8834 constant ggml_metal_kargs_mul_mm_id_map0 & args,
8835 device const char * src2,
8836 device char * htpe,
8837 device char * hids,
8838 threadgroup char * shmem [[threadgroup(0)]],
8839 ushort tpitg[[thread_position_in_threadgroup]],
8840 ushort ntg[[threads_per_threadgroup]]) {
8841 const short ide = tpitg; // expert id
8842
8843 uint32_t n_all = 0;
8844
8845 device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21;
8846
8847 for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens
8848 if (i21 + tpitg < args.ne21) {
8849 device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21);
8850
8851 threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20;
8852
8853 #pragma unroll(ne20)
8854 for (short i20 = 0; i20 < ne20; i20++) {
8855 sids[i20] = src2_i32[i20];
8856 }
8857 }
8858
8859 threadgroup_barrier(mem_flags::mem_threadgroup);
8860
8861 for (short t = 0; t < ntg; t++) {
8862 if (i21 + t >= args.ne21) {
8863 break;
8864 }
8865
8866 threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20;
8867
8868 short sel = 0;
8869 #pragma unroll(ne20)
8870 for (short i20 = 0; i20 < ne20; i20++) {
8871 sel += (sids[i20] == ide)*(i20 + 1);
8872 }
8873
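            // sel is (matching slot + 1), or 0 when this token does not use expert ide;
            // the unconditional write below is only kept (slot advanced) when sel > 0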
8874 ids_i32[n_all] = (i21 + t)*ne20 + sel - 1;
8875
8876 n_all += sel > 0;
8877 }
8878
8879 threadgroup_barrier(mem_flags::mem_threadgroup);
8880 }
8881
8882 device uint32_t * tpe_u32 = (device uint32_t *) (htpe);
8883 tpe_u32[ide] = n_all;
8884}
8885
8886typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t;
8887
8888template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>;
8889template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>;
8890template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
8891template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>;
8892template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
8893template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
8894template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
8895template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
8896
8897template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4>
8898kernel void kernel_mul_mm_id(
8899 constant ggml_metal_kargs_mul_mm_id & args,
8900 device const char * src0,
8901 device const char * src1,
8902 device const char * htpe,
8903 device const char * hids,
8904 device char * dst,
8905 threadgroup char * shmem [[threadgroup(0)]],
8906 uint3 tgpig[[threadgroup_position_in_grid]],
8907 ushort tiitg[[thread_index_in_threadgroup]],
8908 ushort tiisg[[thread_index_in_simdgroup]],
8909 ushort sgitg[[simdgroup_index_in_threadgroup]]) {
8910 threadgroup S0 * sa = (threadgroup S0 *)(shmem);
8911 threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
8912
8913 threadgroup float * sc = (threadgroup float *)(shmem);
8914
8915 constexpr int NR0 = 64;
8916 constexpr int NR1 = 32;
8917
8918 constexpr int NK = 32;
8919 constexpr int NL0 = NK/16;
8920 constexpr int NL1 = NK/8;
8921
8922 const int im = tgpig.z; // expert
8923 const int r0 = tgpig.y*NR0;
8924 const int r1 = tgpig.x*NR1;
8925
8926 device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
8927 device const int32_t * ids_i32 = (device const int32_t *) (hids);
8928
8929 const int32_t neh1 = tpe_u32[im];
8930
8931 if (r1 >= neh1) {
8932 return;
8933 }
8934
8935    // clamp the tile size in case this block is smaller than 64x32 (at the matrix edge)
8936 const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
8937 const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1;
8938
8939 // a thread shouldn't load data outside of the matrix
8940 const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
8941 const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
8942
8943 const short il0 = (tiitg % NL0);
8944
8945 short il = il0;
8946
8947 const int id = ids_i32[im*args.ne21 + r1 + lr1];
8948
8949 const short i11 = (id % args.ne20) % args.ne11;
8950 const short i12 = (id / args.ne20);
8951 const short i13 = 0;
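    // id was packed by kernel_mul_mm_id_map0 as token*ne20 + slot: i12 recovers the
    // token, i11 maps the expert slot onto the available src1 rows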
8952
8953 const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
8954 const short offset1 = il0/nl;
8955
8956 device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
8957
8958 const short iy = 8*(tiitg % NL1);
8959
8960 device const T1 * y = (device const T1 *)(src1
8961 + args.nb13*i13
8962 + args.nb12*i12
8963 + args.nb11*i11
8964 + args.nb10*iy);
8965
8966#ifndef GGML_METAL_HAS_TENSOR
8967 S0_8x8 ma[4];
8968 S1_8x8 mb[2];
8969
8970 simdgroup_float8x8 mc[8];
8971
8972 for (short i = 0; i < 8; i++){
8973 mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
8974 }
8975#else
8976 auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
8977 auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
8978
8979 mpp::tensor_ops::matmul2d<
8980 mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
8981 execution_simdgroups<4>> mm;
8982
8983 auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
8984#endif
8985
8986 for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
8987#ifndef GGML_METAL_HAS_TENSOR
8988 // load data and store to threadgroup memory
8989 if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
8990 threadgroup_barrier(mem_flags::mem_threadgroup);
8991
8992 // no need for dequantization
8993 for (short i = 0; i < 16; i++) {
8994 const short sx = 2*il0 + i/8;
8995 const short sy = (tiitg/NL0)/8;
8996
8997 //const short lx = i%8;
8998 //const short ly = (tiitg/NL0)%8;
8999 const short lx = (tiitg/NL0)%8;
9000 const short ly = i%8;
9001
9002 const short ib = 8*sx + sy;
9003
9004 *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
9005 }
9006 } else {
9007 S0_4x4 temp_a;
9008 dequantize_func(x, il, temp_a);
9009
9010 threadgroup_barrier(mem_flags::mem_threadgroup);
9011
9012 FOR_UNROLL (short i = 0; i < 16; i++) {
9013 const short sx = 2*il0 + i/8;
9014 const short sy = (tiitg/NL0)/8;
9015
9016 //const short lx = i%8;
9017 //const short ly = (tiitg/NL0)%8;
9018 const short lx = (tiitg/NL0)%8;
9019 const short ly = i%8;
9020
9021 const short ib = 8*sx + sy;
9022
9023                // NOTE: the array-subscript form below is massively slower than the pointer form - reason unknown
9024 //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
9025
9026 *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
9027 }
9028 }
9029
9030 if (FC_mul_mm_bc_inp) {
9031 for (short i = 0; i < 8; ++i) {
9032 const short sx = (tiitg%NL1);
9033 const short sy = (tiitg/NL1)/8;
9034
9035 const short lx = i;
9036 const short ly = (tiitg/NL1)%8;
9037 //const short lx = (tiitg/NL1)%8;
9038 //const short ly = i;
9039
9040 const short ib = 4*sx + sy;
9041
9042 *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
9043 }
9044 } else {
9045 const short sx = (tiitg%NL1);
9046 const short sy = (tiitg/NL1)/8;
9047
9051 const short ly = (tiitg/NL1)%8;
9052
9053 const short ib = 4*sx + sy;
9054
9055 *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
9056 }
9057#else
9058 // load data and store to threadgroup memory
9059 if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
9060 threadgroup_barrier(mem_flags::mem_threadgroup);
9061
9062 // no need for dequantization
9063 for (short i = 0; i < 16; i++) {
9064 const short sx = 2*il0 + i/8;
9065 const short sy = (tiitg/NL0)/8;
9066
9067 const short lx = i%8;
9068 const short ly = (tiitg/NL0)%8;
9069 //const short lx = (tiitg/NL0)%8;
9070 //const short ly = i%8;
9071
9072 *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
9073 }
9074 } else {
9075 S0_4x4 temp_a;
9076 dequantize_func(x, il, temp_a);
9077
9078 threadgroup_barrier(mem_flags::mem_threadgroup);
9079
9080 FOR_UNROLL (short i = 0; i < 16; i++) {
9081 const short sx = 2*il0 + i/8;
9082 const short sy = (tiitg/NL0)/8;
9083
9084 const short lx = i%8;
9085 const short ly = (tiitg/NL0)%8;
9086 //const short lx = (tiitg/NL0)%8;
9087 //const short ly = i%8;
9088
9089 *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
9090 }
9091 }
9092
9093 if (FC_mul_mm_bc_inp) {
9094 for (short i = 0; i < 8; ++i) {
9095 const short sx = (tiitg%NL1);
9096 const short sy = (tiitg/NL1)/8;
9097
9098 const short lx = i;
9099 const short ly = (tiitg/NL1)%8;
9100 //const short lx = (tiitg/NL1)%8;
9101 //const short ly = i;
9102
9103 *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
9104 }
9105 } else {
9106 const short sx = (tiitg%NL1);
9107 const short sy = (tiitg/NL1)/8;
9108
9109 //const short lx = i;
9110 const short ly = (tiitg/NL1)%8;
9111 //const short lx = (tiitg/NL1)%8;
9112 //const short ly = i;
9113
9114 *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
9115 }
9116#endif
9117
9118 il = (il + 2 < nl) ? il + 2 : il % 2;
9119 x = (il < 2) ? x + (2 + nl - 1)/nl : x;
9120
9121 y += NK;
9122
9123 threadgroup_barrier(mem_flags::mem_threadgroup);
9124
9125#ifndef GGML_METAL_HAS_TENSOR
9126 // load matrices from threadgroup memory and conduct outer products
9127 threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
9128 threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
9129
9130 FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
9131 simdgroup_barrier(mem_flags::mem_none);
9132
9133 FOR_UNROLL (short i = 0; i < 4; i++) {
9134 simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
9135 }
9136
9137 simdgroup_barrier(mem_flags::mem_none);
9138
9139 FOR_UNROLL (short i = 0; i < 2; i++) {
9140 simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
9141 }
9142
9143 simdgroup_barrier(mem_flags::mem_none);
9144
9145 FOR_UNROLL (short i = 0; i < 8; i++){
9146 simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
9147 }
9148
9149 lsma += 8*64;
9150 lsmb += 4*64;
9151 }
9152#else
9153 auto sA = tA.slice(0, 0);
9154 auto sB = tB.slice(0, 0);
9155
9156 mm.run(sB, sA, cT);
9157#endif
9158 }
9159
9160    // the block may be smaller than 64x32 - stage through threadgroup memory to avoid writing outside of the matrix
9161 threadgroup_barrier(mem_flags::mem_threadgroup);
9162
9163#ifdef GGML_METAL_HAS_TENSOR
9164 auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
9165 cT.store(tC);
9166#else
9167 threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
9168
9169 for (short i = 0; i < 8; i++) {
9170 simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
9171 }
9172#endif
9173
9174 threadgroup_barrier(mem_flags::mem_threadgroup);
9175
9176 for (short j = sgitg; j < nr1; j += 4) {
9177 const int id = ids_i32[im*args.ne21 + r1 + j];
9178
9179 const short ide = id % args.ne20;
9180 const short idt = id / args.ne20;
9181
9182 device float * D = (device float *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
9183 device float4 * D4 = (device float4 *) D;
9184
9185 threadgroup float * C = (threadgroup float *) shmem + j*NR0;
9186 threadgroup float4 * C4 = (threadgroup float4 *) C;
9187
9188 int i = tiisg;
9189 for (; i < nr0/4; i += 32) {
9190 *(D4 + i) = *(C4 + i);
9191 }
9192
9193 i = (4*(nr0/4)) + tiisg;
9194 for (; i < nr0; i += 32) {
9195 *(D + i) = *(C + i);
9196 }
9197 }
9198}
9199
9200#define QK_NL 16
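// 256-weight superblocks are dequantized in 16 chunks of 16 weights (nl = QK_NL = 16)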
9201
9202//
9203// get rows
9204//
9205
9206typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;
9207
9208template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float, float>;
9209template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half, float>;
9210template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
9211#if defined(GGML_METAL_HAS_BF16)
9212template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
9213#endif
9214
9215typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
9216
9217template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
9218template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
9219template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
9220template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
9221template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
9222template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
9223template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
9224template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
9225template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
9226template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>;
9227template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>;
9228template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
9229template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
9230template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
9231template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>;
9232template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>;
9233template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>;
9234template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>;
9235template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>;
9236template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
9237
9238//
9239// set rows
9240//
9241
9242typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;
9243
9244template [[host_name("kernel_set_rows_f32_i64")]] kernel set_rows_f_t kernel_set_rows_f<float, int64_t>;
9245template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;
9246template [[host_name("kernel_set_rows_f16_i64")]] kernel set_rows_f_t kernel_set_rows_f<half, int64_t>;
9247template [[host_name("kernel_set_rows_f16_i32")]] kernel set_rows_f_t kernel_set_rows_f<half, int32_t>;
9248#if defined(GGML_METAL_HAS_BF16)
9249template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;
9250template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;
9251#endif
9252
9253typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;
9254
9255template [[host_name("kernel_set_rows_q8_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>;
9256template [[host_name("kernel_set_rows_q8_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0, quantize_q8_0>;
9257template [[host_name("kernel_set_rows_q4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0, quantize_q4_0>;
9258template [[host_name("kernel_set_rows_q4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0, quantize_q4_0>;
9259template [[host_name("kernel_set_rows_q4_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1, quantize_q4_1>;
9260template [[host_name("kernel_set_rows_q4_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1, quantize_q4_1>;
9261template [[host_name("kernel_set_rows_q5_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0, quantize_q5_0>;
9262template [[host_name("kernel_set_rows_q5_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0, quantize_q5_0>;
9263template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1, quantize_q5_1>;
9264template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1, quantize_q5_1>;
9265template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
9266template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;
9267
9268//
9269// matrix-matrix multiplication
9270//
9271
9272typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_t;
9273
9274template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
9275template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
9276#if defined(GGML_METAL_HAS_BF16)
9277template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
9278#endif
9279template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
9280template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
9281template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
9282template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
9283template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
9284template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
9285template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
9286template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
9287template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
9288template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
9289template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
9290template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
9291template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
9292template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
9293template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
9294template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
9295template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
9296template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
9297template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
9298template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;
9299
9300template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
9301template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
9302template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
9303template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
9304template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
9305template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
9306template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
9307template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
9308template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
9309template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
9310template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
9311template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
9312template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
9313template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
9314template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
9315template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
9316template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
9317template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
9318template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
9319template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
9320template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
9321template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;
9322
9323//
9324// indirect matrix-matrix multiplication
9325//
9326
9327typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_id;
9328
9329template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>;
9330template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>;
9331#if defined(GGML_METAL_HAS_BF16)
9332template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>;
9333#endif
9334template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>;

template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>;

//
// matrix-vector multiplication
//

typedef void (kernel_mul_mv_disp_t)(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        uint3 tgpig,
        ushort tiisg);

typedef void (kernel_mul_mv2_disp_t)(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3 tgpig,
        ushort tiisg,
        ushort sgitg);

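// The two mmv_fn overloads below adapt both dispatch signatures (with and
// without a threadgroup shared-memory argument) to one uniform signature, so
// kernel_mul_mv_id can be instantiated over either kind of mul_mv
// implementation; the parameters a given dispatch does not need are dropped.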
template<kernel_mul_mv_disp_t disp_fn>
void mmv_fn(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3 tgpig,
        ushort tiitg,
        ushort tiisg,
        ushort sgitg) {
    disp_fn(args, src0, src1, dst, tgpig, tiisg);
}

template<kernel_mul_mv2_disp_t disp_fn>
void mmv_fn(
        ggml_metal_kargs_mul_mv args,
        device const char * src0,
        device const char * src1,
        device char * dst,
        threadgroup char * shmem,
        uint3 tgpig,
        ushort tiitg,
        ushort tiisg,
        ushort sgitg) {
    disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}

typedef decltype(mmv_fn<kernel_mul_mv_t_t_disp<half, half, ggml_metal_kargs_mul_mv>>) mul_mv_disp_fn_t;

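// Indirect (expert-selecting) matrix-vector multiplication: tgpig.z packs an
// outer index iid1 and an expert slot idx (tgpig.z = iid1*nei0 + idx). The
// ids buffer maps the slot to an expert index i02, which picks the weight
// matrix inside src0s; the selected sub-problem is then forwarded to a plain
// mul_mv dispatch with the batch dimensions collapsed to 1 (see args0 below).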
template<mul_mv_disp_fn_t disp_fn>
kernel void kernel_mul_mv_id(
        constant ggml_metal_kargs_mul_mv_id & args,
        device const char * src0s,
        device const char * src1,
        device char * dst,
        device const char * ids,
        threadgroup char * shmem [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
    const int iid1 = tgpig.z/args.nei0;
    const int idx = tgpig.z%args.nei0;

    tgpig.z = 0;

    const int32_t i02 = ((device const int32_t *) (ids + iid1*args.nbi1))[idx];

    const int64_t i11 = idx % args.ne11;
    const int64_t i12 = iid1;

    const int64_t i1 = idx;
    const int64_t i2 = i12;

    device const char * src0_cur = src0s + i02*args.nb02;
    device const char * src1_cur = src1 + i11*args.nb11 + i12*args.nb12;

    device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float);

    ggml_metal_kargs_mul_mv args0 = {
        /*.ne00 =*/ args.ne00,
        /*.ne01 =*/ args.ne01,
        /*.ne02 =*/ 1, // args.ne02,
        /*.nb00 =*/ args.nb00,
        /*.nb01 =*/ args.nb01,
        /*.nb02 =*/ args.nb02,
        /*.nb03 =*/ args.nb02, // args.ne02 == 1
        /*.ne10 =*/ args.ne10,
        /*.ne11 =*/ 1, // args.ne11,
        /*.ne12 =*/ 1, // args.ne12,
        /*.nb10 =*/ args.nb10,
        /*.nb11 =*/ args.nb11,
        /*.nb12 =*/ args.nb12,
        /*.nb13 =*/ args.nb12, // ne12 == 1
        /*.ne0 =*/ args.ne0,
        /*.ne1 =*/ 1, // args.ne1,
        /*.nr0 =*/ args.nr0,
        /*.r2 =*/ 1,
        /*.r3 =*/ 1,
    };

    disp_fn(
        args0,
        /* src0 */ src0_cur,
        /* src1 */ src1_cur,
        /* dst */ dst_cur,
        shmem,
        tgpig,
        tiitg,
        tiisg,
        sgitg);
}

typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>) kernel_mul_mv_id_t;

typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>) kernel_mul_mv_id_4_t;

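// One named entry point per (src0 type, src1 type) pair, mirroring the
// non-indirect mul_mv kernels; the _4 variants presumably take the
// vectorized (float4/half4) dispatch path.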
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>;
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<half, float>>>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<bfloat, float>>>;
#endif
template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>;
template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<half, half4, float, float4>>>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<bfloat, bfloat4, float, float4>>>;
#endif

template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;

template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1>>>;

template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4>>>;

template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K>>>;
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K>>>;
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K>>>;
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K>>>;
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K>>>;
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S>>>;
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M>>>;
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS>>>;
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS>>>;
template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS>>>;
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S>>>;
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S>>>;
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL>>>;
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS>>>;

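// 2D pooling: each thread produces one output element. gid is decomposed
// into channel (nc), output row (cur_oh) and output column (cur_ow), and the
// input window is clamped to the image bounds. E.g. (hypothetical values)
// with IW = 5, k0 = 3, s0 = 2, p0 = 1: output column 0 reads input columns
// 0..1, and output column 1 reads input columns 1..3.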
kernel void kernel_pool_2d_max_f32(
        constant ggml_metal_kargs_pool_2d & args,
        device const float * src0,
        device float * dst,
        uint gid[[thread_position_in_grid]]) {

    if (gid >= args.np) {
        return;
    }

    const int idx = gid;
    const int I_HW = args.IH * args.IW;
    const int O_HW = args.OH * args.OW;
    const int nc = idx / O_HW;
    const int cur_oh = idx % O_HW / args.OW;
    const int cur_ow = idx % O_HW % args.OW;

    device const float * i_ptr = src0 + nc * I_HW;
    device float * o_ptr = dst + nc * O_HW;

    const int start_h = cur_oh * args.s1 - args.p1;
    const int bh = MAX(0, start_h);
    const int eh = MIN(args.IH, start_h + args.k1);
    const int start_w = cur_ow * args.s0 - args.p0;
    const int bw = MAX(0, start_w);
    const int ew = MIN(args.IW, start_w + args.k0);

    float res = -INFINITY;

    for (int i = bh; i < eh; i += 1) {
        for (int j = bw; j < ew; j += 1) {
            res = MAX(res, i_ptr[i * args.IW + j]);
        }
    }

    o_ptr[cur_oh * args.OW + cur_ow] = res;
}

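// Note: the average variant scales by the full kernel area k0*k1 rather than
// by the clipped window size (the commented-out scale below), so padded
// positions count as zeros in the mean.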
kernel void kernel_pool_2d_avg_f32(
        constant ggml_metal_kargs_pool_2d & args,
        device const float * src0,
        device float * dst,
        uint gid[[thread_position_in_grid]]) {

    if (gid >= args.np) {
        return;
    }

    const int idx = gid;
    const int I_HW = args.IH * args.IW;
    const int O_HW = args.OH * args.OW;
    const int nc = idx / O_HW;
    const int cur_oh = idx % O_HW / args.OW;
    const int cur_ow = idx % O_HW % args.OW;

    device const float * i_ptr = src0 + nc * I_HW;
    device float * o_ptr = dst + nc * O_HW;

    const int start_h = cur_oh * args.s1 - args.p1;
    const int bh = MAX(0, start_h);
    const int eh = MIN(args.IH, start_h + args.k1);
    const int start_w = cur_ow * args.s0 - args.p0;
    const int bw = MAX(0, start_w);
    const int ew = MIN(args.IW, start_w + args.k0);
    // const float scale = 1. / ((eh - bh) * (ew - bw));
    const float scale = 1. / (args.k0 * args.k1);

    float res = 0;

    for (int i = bh; i < eh; i += 1) {
        for (int j = bw; j < ew; j += 1) {
            float cur = i_ptr[i * args.IW + j];
            res += cur * scale;
        }
    }

    o_ptr[cur_oh * args.OW + cur_ow] = res;
}

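// 1D pooling over rows. Unlike the 2D average kernel above, the 1D average
// divides by the number of in-bounds taps (cnt), so padding does not
// contribute to the mean.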
kernel void kernel_pool_1d_max_f32(
        constant ggml_metal_kargs_pool_1d & args,
        device const float * src,
        device float * dst,
        uint gid [[thread_position_in_grid]]
) {

    if (gid >= args.np) {
        return;
    }

    const int ow = (int)gid % args.OW;
    const int row = (int)gid / args.OW;

    const int base = ow * args.s0 - args.p0;

    float acc = -INFINITY;

    const int src_off = row * args.IW;
    const int dst_off = row * args.OW;

    for (int ki = 0; ki < args.k0; ++ki) {
        int j = base + ki;
        if (j < 0 || j >= args.IW) {
            continue;
        }
        float v = src[src_off + j];
        acc = max(acc, v);
    }

    dst[dst_off + ow] = acc;
}

kernel void kernel_pool_1d_avg_f32(
        constant ggml_metal_kargs_pool_1d & args,
        device const float * src,
        device float * dst,
        uint gid [[thread_position_in_grid]]
) {

    if (gid >= args.np) {
        return;
    }

    const int ow = (int)gid % args.OW;
    const int row = (int)gid / args.OW;

    const int base = ow * args.s0 - args.p0;

    float acc = 0.0f;
    int cnt = 0;

    const int src_off = row * args.IW;
    const int dst_off = row * args.OW;

    for (int ki = 0; ki < args.k0; ++ki) {
        const int j = base + ki;
        if (j < 0 || j >= args.IW) {
            continue;
        }
        acc += src[src_off + j];
        cnt += 1;
    }

    dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
}

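// Fused AdamW step: updates the moment estimates g_m/g_v and the parameter x
// in a single pass. pars packs {alpha, beta1, beta2, eps, wd, beta1h, beta2h},
// where beta1h and beta2h are presumably the host-precomputed bias-correction
// factors 1/(1 - beta1^t) and 1/(1 - beta2^t), giving
//   x <- x*(1 - alpha*wd) - alpha*m_hat/(sqrt(v_hat) + eps)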
kernel void kernel_opt_step_adamw_f32(
        constant ggml_metal_kargs_opt_step_adamw & args,
        device float * x,
        device const float * g,
        device float * g_m,
        device float * g_v,
        device const float * pars,
        uint gid[[thread_position_in_grid]]) {

    if (gid >= args.np) {
        return;
    }

    const float alpha = pars[0];
    const float beta1 = pars[1];
    const float beta2 = pars[2];
    const float eps = pars[3];
    const float wd = pars[4];
    const float beta1h = pars[5];
    const float beta2h = pars[6];

    const float gi = g[gid];
    const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1);
    const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2);

    g_m[gid] = gmi;
    g_v[gid] = gvi;

    const float mh = gmi * beta1h;
    const float vh = sqrt(gvi * beta2h) + eps;

    x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh;
}

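// SGD with decoupled weight decay: pars[0] and pars[1] are presumably the
// learning rate and weight-decay coefficient, so this computes
//   x <- x*(1 - lr*wd) - lr*g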
kernel void kernel_opt_step_sgd_f32(
        constant ggml_metal_kargs_opt_step_sgd & args,
        device float * x,
        device const float * g,
        device const float * pars,
        uint gid[[thread_position_in_grid]]) {

    if (gid >= args.np) {
        return;
    }

    x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}

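// Pattern fill: each thread writes args.val to one element of dst; only the
// int64_t instantiation is exposed below.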
template<typename T>
kernel void kernel_memset(
        constant ggml_metal_kargs_memset & args,
        device T * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = args.val;
}

typedef decltype(kernel_memset<int64_t>) kernel_memset_t;

template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;

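// Counts matching elements of src0 and src1 via a two-level reduction:
// simd_sum within each simdgroup, partials staged in threadgroup memory,
// then simdgroup 0 sums the NSG partials (as float, exact for counts below
// 2^24) and accumulates the result into dst with a relaxed atomic add. NSG
// comes from the FC_count_equal_nsg function constant fixed at pipeline
// creation.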
constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];

template<typename T>
kernel void kernel_count_equal(
        constant ggml_metal_kargs_count_equal & args,
        device const char * src0,
        device const char * src1,
        device atomic_int * dst,
        threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const short NSG = FC_count_equal_nsg;

    const int i3 = tgpig.z;
    const int i2 = tgpig.y;
    const int i1 = tgpig.x;

    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
        return;
    }

    int sum = 0;

    device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
    device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;

    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        const T v0 = *(device const T *)(base0 + i0*args.nb00);
        const T v1 = *(device const T *)(base1 + i0*args.nb10);
        sum += (v0 == v1);
    }

    sum = simd_sum(sum);

    if (tiisg == 0) {
        shmem_i32[sgitg] = sum;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (sgitg == 0) {
        float v = 0.0f;
        if (tpitg.x < NSG) {
            v = shmem_i32[tpitg.x];
        }

        float total = simd_sum(v);
        if (tpitg.x == 0) {
            atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);
        }
    }
}

typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;

template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;