| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-metal | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-metal')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-metal/CMakeLists.txt | 124 |
| -rw-r--r-- | llama.cpp/ggml/src/ggml-metal/ggml-metal-common.cpp | 446 |
| -rw-r--r-- | llama.cpp/ggml/src/ggml-metal/ggml-metal-common.h | 52 |
| -rw-r--r-- | llama.cpp/ggml/src/ggml-metal/ggml-metal-context.h | 41 |
| -rw-r--r-- | llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m | 702 |
| -rw-r--r-- | llama.cpp/ggml/src/ggml-metal/ggml-metal-device.cpp | 1875 |
| -rw-r--r-- | llama.cpp/ggml/src/ggml-metal/ggml-metal-device.h | 290 |
| -rw-r--r-- | llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m | 1748 |
| -rw-r--r-- | llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h | 1051 |
| -rw-r--r-- | llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp | 4222 |
| -rw-r--r-- | llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.h | 93 |
| -rw-r--r-- | llama.cpp/ggml/src/ggml-metal/ggml-metal.cpp | 937 |
| -rw-r--r-- | llama.cpp/ggml/src/ggml-metal/ggml-metal.metal | 9798 |
13 files changed, 21379 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt b/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt
new file mode 100644
index 0000000..42054d8
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt
| @@ -0,0 +1,124 @@ | |||
| 1 | find_library(FOUNDATION_LIBRARY Foundation REQUIRED) | ||
| 2 | find_library(METAL_FRAMEWORK Metal REQUIRED) | ||
| 3 | find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) | ||
| 4 | |||
| 5 | message(STATUS "Metal framework found") | ||
| 6 | |||
| 7 | ggml_add_backend_library(ggml-metal | ||
| 8 | ggml-metal.cpp | ||
| 9 | ggml-metal-device.m | ||
| 10 | ggml-metal-device.cpp | ||
| 11 | ggml-metal-common.cpp | ||
| 12 | ggml-metal-context.m | ||
| 13 | ggml-metal-ops.cpp | ||
| 14 | ) | ||
| 15 | |||
| 16 | target_link_libraries(ggml-metal PRIVATE | ||
| 17 | ${FOUNDATION_LIBRARY} | ||
| 18 | ${METAL_FRAMEWORK} | ||
| 19 | ${METALKIT_FRAMEWORK} | ||
| 20 | ) | ||
| 21 | |||
| 22 | if (GGML_METAL_NDEBUG) | ||
| 23 | add_compile_definitions(GGML_METAL_NDEBUG) | ||
| 24 | endif() | ||
| 25 | |||
| 26 | set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h") | ||
| 27 | if (GGML_METAL_EMBED_LIBRARY) | ||
| 28 | enable_language(ASM) | ||
| 29 | |||
| 30 | add_compile_definitions(GGML_METAL_EMBED_LIBRARY) | ||
| 31 | |||
| 32 | set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") | ||
| 33 | set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h") | ||
| 34 | |||
| 35 | file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated") | ||
| 36 | |||
| 37 | # merge ggml-common.h and ggml-metal.metal into a single file | ||
| 38 | set(METALLIB_EMBED_ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.s") | ||
| 39 | set(METALLIB_SOURCE_EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") | ||
| 40 | set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp") | ||
| 41 | |||
| 42 | add_custom_command( | ||
| 43 | OUTPUT "${METALLIB_EMBED_ASM}" | ||
| 44 | COMMAND echo "Embedding Metal library" | ||
| 45 | COMMAND sed -e "/__embed_ggml-common.h__/r ${METALLIB_COMMON}" -e "/__embed_ggml-common.h__/d" < "${METALLIB_SOURCE}" > "${METALLIB_SOURCE_EMBED_TMP}" | ||
| 46 | COMMAND sed -e "/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}" -e "/\#include \"ggml-metal-impl.h\"/d" < "${METALLIB_SOURCE_EMBED_TMP}" > "${METALLIB_SOURCE_EMBED}" | ||
| 47 | COMMAND echo ".section __DATA,__ggml_metallib" > "${METALLIB_EMBED_ASM}" | ||
| 48 | COMMAND echo ".globl _ggml_metallib_start" >> "${METALLIB_EMBED_ASM}" | ||
| 49 | COMMAND echo "_ggml_metallib_start:" >> "${METALLIB_EMBED_ASM}" | ||
| 50 | COMMAND echo .incbin "\"${METALLIB_SOURCE_EMBED}\"" >> "${METALLIB_EMBED_ASM}" | ||
| 51 | COMMAND echo ".globl _ggml_metallib_end" >> "${METALLIB_EMBED_ASM}" | ||
| 52 | COMMAND echo "_ggml_metallib_end:" >> "${METALLIB_EMBED_ASM}" | ||
| 53 | DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h | ||
| 54 | COMMENT "Generate assembly for embedded Metal library" | ||
| 55 | VERBATIM | ||
| 56 | ) | ||
| 57 | |||
| 58 | target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}") | ||
| 59 | else() | ||
| 60 | # copy metal files to bin directory | ||
| 61 | configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) | ||
| 62 | configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) | ||
| 63 | configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY) | ||
| 64 | |||
| 65 | if (GGML_METAL_SHADER_DEBUG) | ||
| 66 | # custom command to do the following: | ||
| 67 | # xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air | ||
| 68 | # xcrun -sdk macosx metallib ggml-metal.air -o default.metallib | ||
| 69 | # | ||
| 70 | # note: this is the only way I found to disable fast-math in Metal. it's ugly, but at least it works | ||
| 71 | # disabling fast math is needed in order to pass tests/test-backend-ops | ||
| 72 | # note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1 | ||
| 73 | # note: unfortunately, we have to call it default.metallib instead of ggml.metallib | ||
| 74 | # ref: https://github.com/ggml-org/whisper.cpp/issues/1720 | ||
| 75 | # note: adding -g causes segmentation fault during compile | ||
| 76 | #set(XC_FLAGS -fno-fast-math -fno-inline -g) | ||
| 77 | set(XC_FLAGS -fno-fast-math -fno-inline) | ||
| 78 | else() | ||
| 79 | set(XC_FLAGS -O3) | ||
| 80 | endif() | ||
| 81 | |||
| 82 | # Append macOS metal versioning flags | ||
| 83 | if (GGML_METAL_MACOSX_VERSION_MIN) | ||
| 84 | message(STATUS "Adding -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN} flag to metal compilation") | ||
| 85 | list (APPEND XC_FLAGS -mmacosx-version-min=${GGML_METAL_MACOSX_VERSION_MIN}) | ||
| 86 | endif() | ||
| 87 | |||
| 88 | if (GGML_METAL_STD) | ||
| 89 | message(STATUS "Adding -std=${GGML_METAL_STD} flag to metal compilation") | ||
| 90 | list (APPEND XC_FLAGS -std=${GGML_METAL_STD}) | ||
| 91 | endif() | ||
| 92 | |||
| 93 | add_custom_command( | ||
| 94 | OUTPUT ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib | ||
| 95 | COMMAND xcrun -sdk macosx metal ${XC_FLAGS} -c ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal -o - | | ||
| 96 | xcrun -sdk macosx metallib - -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib | ||
| 97 | COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h | ||
| 98 | COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal | ||
| 99 | DEPENDS ggml-metal.metal ${METALLIB_COMMON} | ||
| 100 | COMMENT "Compiling Metal kernels" | ||
| 101 | ) | ||
| 102 | |||
| 103 | # FIXME: only add to the ggml-metal target? | ||
| 104 | add_custom_target( | ||
| 105 | ggml-metal-lib ALL | ||
| 106 | DEPENDS ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib | ||
| 107 | ) | ||
| 108 | endif() # GGML_METAL_EMBED_LIBRARY | ||
| 109 | |||
| 110 | if (NOT GGML_METAL_EMBED_LIBRARY) | ||
| 111 | install( | ||
| 112 | FILES src/ggml-metal/ggml-metal.metal | ||
| 113 | PERMISSIONS | ||
| 114 | OWNER_READ | ||
| 115 | OWNER_WRITE | ||
| 116 | GROUP_READ | ||
| 117 | WORLD_READ | ||
| 118 | DESTINATION ${CMAKE_INSTALL_BINDIR}) | ||
| 119 | |||
| 120 | install( | ||
| 121 | FILES ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib | ||
| 122 | DESTINATION ${CMAKE_INSTALL_BINDIR} | ||
| 123 | ) | ||
| 124 | endif() | ||
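
The `GGML_METAL_EMBED_LIBRARY` branch above generates an assembly file that pulls the merged Metal source into the binary via `.incbin`, bracketed by the `_ggml_metallib_start` and `_ggml_metallib_end` labels. Below is a minimal sketch of how such an embedded blob could be read back from C++; the symbol names come from the generated assembly, but the reader function itself is an illustration, not code from this commit:

```cpp
// Sketch only: recover the Metal source embedded by ggml-metal-embed.s.
// The extern symbols match the labels emitted above (the leading underscore
// is added by Mach-O C name mangling).
#include <string>

extern "C" {
    extern const char ggml_metallib_start[];
    extern const char ggml_metallib_end[];
}

static std::string ggml_embedded_metal_source() {
    const size_t size = (size_t)(ggml_metallib_end - ggml_metallib_start);
    return std::string(ggml_metallib_start, size); // the merged .metal text
}
```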
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.cpp b/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.cpp
new file mode 100644
index 0000000..95627d3
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.cpp
| @@ -0,0 +1,446 @@ | |||
| 1 | #include "ggml-metal-common.h" | ||
| 2 | |||
| 3 | #include "ggml-impl.h" | ||
| 4 | #include "ggml-backend-impl.h" | ||
| 5 | |||
| 6 | #include <vector> | ||
| 7 | |||
| 8 | // represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb) | ||
| 9 | // the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it) | ||
| 10 | struct ggml_mem_range { | ||
| 11 | uint64_t pb; // buffer id | ||
| 12 | |||
| 13 | uint64_t p0; // begin | ||
| 14 | uint64_t p1; // end | ||
| 15 | |||
| 16 | ggml_mem_range_type pt; | ||
| 17 | }; | ||
| 18 | |||
| 19 | struct ggml_mem_ranges { | ||
| 20 | std::vector<ggml_mem_range> ranges; | ||
| 21 | |||
| 22 | int debug = 0; | ||
| 23 | }; | ||
| 24 | |||
| 25 | ggml_mem_ranges_t ggml_mem_ranges_init(int debug) { | ||
| 26 | auto * res = new ggml_mem_ranges; | ||
| 27 | |||
| 28 | res->ranges.reserve(256); | ||
| 29 | res->debug = debug; | ||
| 30 | |||
| 31 | return res; | ||
| 32 | } | ||
| 33 | |||
| 34 | void ggml_mem_ranges_free(ggml_mem_ranges_t mrs) { | ||
| 35 | delete mrs; | ||
| 36 | } | ||
| 37 | |||
| 38 | void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs) { | ||
| 39 | mrs->ranges.clear(); | ||
| 40 | } | ||
| 41 | |||
| 42 | static bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, ggml_mem_range mr) { | ||
| 43 | mrs->ranges.push_back(mr); | ||
| 44 | |||
| 45 | return true; | ||
| 46 | } | ||
| 47 | |||
| 48 | static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggml_mem_range_type pt) { | ||
| 49 | // always use the base tensor | ||
| 50 | tensor = tensor->view_src ? tensor->view_src : tensor; | ||
| 51 | |||
| 52 | GGML_ASSERT(!tensor->view_src); | ||
| 53 | |||
| 54 | ggml_mem_range mr; | ||
| 55 | |||
| 56 | if (tensor->buffer) { | ||
| 57 | // when the tensor is allocated, use the actual memory address range in the buffer | ||
| 58 | // | ||
| 59 | // take the actual allocated size with ggml_backend_buft_get_alloc_size() | ||
| 60 | // this can be larger than the tensor size if the buffer type allocates extra memory | ||
| 61 | // ref: https://github.com/ggml-org/llama.cpp/pull/15966 | ||
| 62 | mr = { | ||
| 63 | /*.pb =*/ (uint64_t) tensor->buffer, | ||
| 64 | /*.p0 =*/ (uint64_t) tensor->data, | ||
| 65 | /*.p1 =*/ (uint64_t) tensor->data + ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor), | ||
| 66 | /*.pt =*/ pt, | ||
| 67 | }; | ||
| 68 | } else { | ||
| 69 | // otherwise, the pointer address is used as a unique id of the memory ranges | ||
| 70 | // that the tensor will be using when it is allocated | ||
| 71 | mr = { | ||
| 72 | /*.pb =*/ (uint64_t) tensor, | ||
| 73 | /*.p0 =*/ 0, // | ||
| 74 | /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used | ||
| 75 | /*.pt =*/ pt, | ||
| 76 | }; | ||
| 77 | }; | ||
| 78 | |||
| 79 | return mr; | ||
| 80 | } | ||
| 81 | |||
| 82 | static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) { | ||
| 83 | return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC); | ||
| 84 | } | ||
| 85 | |||
| 86 | static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor) { | ||
| 87 | return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST); | ||
| 88 | } | ||
| 89 | |||
| 90 | static bool ggml_mem_ranges_add_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { | ||
| 91 | GGML_ASSERT(tensor); | ||
| 92 | |||
| 93 | ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor); | ||
| 94 | |||
| 95 | if (mrs->debug > 2) { | ||
| 96 | GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1); | ||
| 97 | } | ||
| 98 | |||
| 99 | return ggml_mem_ranges_add(mrs, mr); | ||
| 100 | } | ||
| 101 | |||
| 102 | static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { | ||
| 103 | GGML_ASSERT(tensor); | ||
| 104 | |||
| 105 | ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor); | ||
| 106 | |||
| 107 | if (mrs->debug > 2) { | ||
| 108 | GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1); | ||
| 109 | } | ||
| 110 | |||
| 111 | return ggml_mem_ranges_add(mrs, mr); | ||
| 112 | } | ||
| 113 | |||
| 114 | bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { | ||
| 115 | for (int i = 0; i < GGML_MAX_SRC; i++) { | ||
| 116 | if (tensor->src[i]) { | ||
| 117 | ggml_mem_ranges_add_src(mrs, tensor->src[i]); | ||
| 118 | } | ||
| 119 | } | ||
| 120 | |||
| 121 | return ggml_mem_ranges_add_dst(mrs, tensor); | ||
| 122 | } | ||
| 123 | |||
| 124 | static bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, ggml_mem_range mr) { | ||
| 125 | for (size_t i = 0; i < mrs->ranges.size(); i++) { | ||
| 126 | const auto & cmp = mrs->ranges[i]; | ||
| 127 | |||
| 128 | // two memory ranges cannot intersect if they are in different buffers | ||
| 129 | if (mr.pb != cmp.pb) { | ||
| 130 | continue; | ||
| 131 | } | ||
| 132 | |||
| 133 | // intersecting source ranges are allowed | ||
| 134 | if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) { | ||
| 135 | continue; | ||
| 136 | } | ||
| 137 | |||
| 138 | if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) { | ||
| 139 | if (mrs->debug > 2) { | ||
| 140 | GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n", | ||
| 141 | __func__, | ||
| 142 | mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst", | ||
| 143 | mr.pb, mr.p0, mr.p1, | ||
| 144 | cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst", | ||
| 145 | cmp.pb, cmp.p0, cmp.p1); | ||
| 146 | } | ||
| 147 | |||
| 148 | return false; | ||
| 149 | } | ||
| 150 | } | ||
| 151 | |||
| 152 | return true; | ||
| 153 | } | ||
| 154 | |||
| 155 | static bool ggml_mem_ranges_check_src(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { | ||
| 156 | GGML_ASSERT(tensor); | ||
| 157 | |||
| 158 | ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor); | ||
| 159 | |||
| 160 | const bool res = ggml_mem_ranges_check(mrs, mr); | ||
| 161 | |||
| 162 | return res; | ||
| 163 | } | ||
| 164 | |||
| 165 | static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { | ||
| 166 | GGML_ASSERT(tensor); | ||
| 167 | |||
| 168 | ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor); | ||
| 169 | |||
| 170 | const bool res = ggml_mem_ranges_check(mrs, mr); | ||
| 171 | |||
| 172 | return res; | ||
| 173 | } | ||
| 174 | |||
| 175 | bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { | ||
| 176 | for (int i = 0; i < GGML_MAX_SRC; i++) { | ||
| 177 | if (tensor->src[i]) { | ||
| 178 | if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) { | ||
| 179 | return false; | ||
| 180 | } | ||
| 181 | } | ||
| 182 | } | ||
| 183 | |||
| 184 | return ggml_mem_ranges_check_dst(mrs, tensor); | ||
| 185 | } | ||
| 186 | |||
| 187 | struct node_info { | ||
| 188 | ggml_tensor * node; | ||
| 189 | |||
| 190 | std::vector<ggml_tensor *> fused; | ||
| 191 | |||
| 192 | ggml_op op() const { | ||
| 193 | return node->op; | ||
| 194 | } | ||
| 195 | |||
| 196 | const ggml_tensor * dst() const { | ||
| 197 | return fused.empty() ? node : fused.back(); | ||
| 198 | } | ||
| 199 | |||
| 200 | bool is_empty() const { | ||
| 201 | return ggml_op_is_empty(node->op); | ||
| 202 | } | ||
| 203 | |||
| 204 | void add_fused(ggml_tensor * t) { | ||
| 205 | fused.push_back(t); | ||
| 206 | } | ||
| 207 | }; | ||
| 208 | |||
| 209 | static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) { | ||
| 210 | // helper to add node src and dst ranges | ||
| 211 | const auto & h_add = [](ggml_mem_ranges_t mrs, const node_info & node) { | ||
| 212 | for (int i = 0; i < GGML_MAX_SRC; i++) { | ||
| 213 | if (node.node->src[i]) { | ||
| 214 | if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) { | ||
| 215 | return false; | ||
| 216 | } | ||
| 217 | } | ||
| 218 | } | ||
| 219 | |||
| 220 | // keep track of the sources of the fused nodes as well | ||
| 221 | for (const auto * fused : node.fused) { | ||
| 222 | for (int i = 0; i < GGML_MAX_SRC; i++) { | ||
| 223 | if (fused->src[i]) { | ||
| 224 | if (!ggml_mem_ranges_add_src(mrs, fused->src[i])) { | ||
| 225 | return false; | ||
| 226 | } | ||
| 227 | } | ||
| 228 | } | ||
| 229 | } | ||
| 230 | |||
| 231 | return ggml_mem_ranges_add_dst(mrs, node.dst()); | ||
| 232 | }; | ||
| 233 | |||
| 234 | // helper to check if a node can run concurrently with the existing set of nodes | ||
| 235 | const auto & h_check = [](ggml_mem_ranges_t mrs, const node_info & node) { | ||
| 236 | for (int i = 0; i < GGML_MAX_SRC; i++) { | ||
| 237 | if (node.node->src[i]) { | ||
| 238 | if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) { | ||
| 239 | return false; | ||
| 240 | } | ||
| 241 | } | ||
| 242 | } | ||
| 243 | |||
| 244 | for (const auto * fused : node.fused) { | ||
| 245 | for (int i = 0; i < GGML_MAX_SRC; i++) { | ||
| 246 | if (fused->src[i]) { | ||
| 247 | if (!ggml_mem_ranges_check_src(mrs, fused->src[i])) { | ||
| 248 | return false; | ||
| 249 | } | ||
| 250 | } | ||
| 251 | } | ||
| 252 | } | ||
| 253 | |||
| 254 | return ggml_mem_ranges_check_dst(mrs, node.dst()); | ||
| 255 | }; | ||
| 256 | |||
| 257 | // perform reorders only across these types of ops | ||
| 258 | // can be expanded when needed | ||
| 259 | const auto & h_safe = [](ggml_op op) { | ||
| 260 | switch (op) { | ||
| 261 | case GGML_OP_MUL_MAT: | ||
| 262 | case GGML_OP_MUL_MAT_ID: | ||
| 263 | case GGML_OP_ROPE: | ||
| 264 | case GGML_OP_NORM: | ||
| 265 | case GGML_OP_RMS_NORM: | ||
| 266 | case GGML_OP_GROUP_NORM: | ||
| 267 | case GGML_OP_SUM_ROWS: | ||
| 268 | case GGML_OP_MUL: | ||
| 269 | case GGML_OP_ADD: | ||
| 270 | case GGML_OP_DIV: | ||
| 271 | case GGML_OP_GLU: | ||
| 272 | case GGML_OP_SCALE: | ||
| 273 | case GGML_OP_GET_ROWS: | ||
| 274 | case GGML_OP_CPY: | ||
| 275 | case GGML_OP_SET_ROWS: | ||
| 276 | return true; | ||
| 277 | default: | ||
| 278 | return ggml_op_is_empty(op); | ||
| 279 | } | ||
| 280 | }; | ||
| 281 | |||
| 282 | const int n = nodes.size(); | ||
| 283 | |||
| 284 | std::vector<int> res; | ||
| 285 | res.reserve(n); | ||
| 286 | |||
| 287 | std::vector<bool> used(n, false); | ||
| 288 | |||
| 289 | // the memory ranges for the set of currently concurrent nodes | ||
| 290 | ggml_mem_ranges_t mrs0 = ggml_mem_ranges_init(0); | ||
| 291 | |||
| 292 | // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder | ||
| 293 | ggml_mem_ranges_t mrs1 = ggml_mem_ranges_init(0); | ||
| 294 | |||
| 295 | for (int i0 = 0; i0 < n; i0++) { | ||
| 296 | if (used[i0]) { | ||
| 297 | continue; | ||
| 298 | } | ||
| 299 | |||
| 300 | const auto & node0 = nodes[i0]; | ||
| 301 | |||
| 302 | // the node is not concurrent with the existing concurrent set, so we have to "put a barrier" (i.e. reset mrs0) | ||
| 303 | // but before we do that, look forward for some other nodes that can be added to the concurrent set mrs0 | ||
| 304 | // | ||
| 305 | // note: we can always add empty nodes to the concurrent set as they neither read nor write anything | ||
| 306 | if (!node0.is_empty() && !h_check(mrs0, node0)) { | ||
| 307 | // this will hold the set of memory ranges from the nodes that haven't been processed yet | ||
| 308 | // if a node is not concurrent with this set, we cannot reorder it | ||
| 309 | ggml_mem_ranges_reset(mrs1); | ||
| 310 | |||
| 311 | // initialize it with the current node | ||
| 312 | h_add(mrs1, node0); | ||
| 313 | |||
| 314 | // search this many nodes forward for a concurrent node | ||
| 315 | constexpr int N_FORWARD = 8; | ||
| 316 | |||
| 317 | for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) { | ||
| 318 | if (used[i1]) { | ||
| 319 | continue; | ||
| 320 | } | ||
| 321 | |||
| 322 | const auto & node1 = nodes[i1]; | ||
| 323 | |||
| 324 | // disallow reordering of certain ops | ||
| 325 | if (!h_safe(node1.op())) { | ||
| 326 | break; | ||
| 327 | } | ||
| 328 | |||
| 329 | const bool is_empty = node1.is_empty(); | ||
| 330 | |||
| 331 | // to reorder a node and add it to the concurrent set, it has to be: | ||
| 332 | // + empty or concurrent with all nodes in the existing concurrent set (mrs0) | ||
| 333 | // + concurrent with all nodes prior to it that haven't been processed yet (mrs1) | ||
| 334 | if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) { | ||
| 335 | // add the node to the existing concurrent set (i.e. reorder it for early execution) | ||
| 336 | h_add(mrs0, node1); | ||
| 337 | res.push_back(i1); | ||
| 338 | |||
| 339 | // mark as used, so we skip re-processing it later | ||
| 340 | used[i1] = true; | ||
| 341 | } else { | ||
| 342 | // expand the set of nodes that haven't been processed yet | ||
| 343 | h_add(mrs1, node1); | ||
| 344 | } | ||
| 345 | } | ||
| 346 | |||
| 347 | // finalize the concurrent set and begin a new one | ||
| 348 | ggml_mem_ranges_reset(mrs0); | ||
| 349 | } | ||
| 350 | |||
| 351 | // expand the concurrent set with the current node | ||
| 352 | { | ||
| 353 | h_add(mrs0, node0); | ||
| 354 | res.push_back(i0); | ||
| 355 | } | ||
| 356 | } | ||
| 357 | |||
| 358 | ggml_mem_ranges_free(mrs0); | ||
| 359 | ggml_mem_ranges_free(mrs1); | ||
| 360 | |||
| 361 | return res; | ||
| 362 | } | ||
| 363 | |||
| 364 | void ggml_graph_optimize(ggml_cgraph * gf) { | ||
| 365 | constexpr int MAX_FUSE = 16; | ||
| 366 | |||
| 367 | const int n = gf->n_nodes; | ||
| 368 | |||
| 369 | enum ggml_op ops[MAX_FUSE]; | ||
| 370 | |||
| 371 | std::vector<node_info> nodes; | ||
| 372 | nodes.reserve(gf->n_nodes); | ||
| 373 | |||
| 374 | // fuse nodes: | ||
| 375 | // we don't want to make reorders that break fusing, so we first pack all fusable tensors | ||
| 376 | // and perform the reorder over the fused nodes. after the reorder is done, we unfuse | ||
| 377 | for (int i = 0; i < n; i++) { | ||
| 378 | node_info node = { | ||
| 379 | /*.node =*/ gf->nodes[i], | ||
| 380 | /*.fused =*/ {}, | ||
| 381 | }; | ||
| 382 | |||
| 383 | // fuse only ops that start with these operations | ||
| 384 | // can be expanded when needed | ||
| 385 | if (node.op() == GGML_OP_ADD || | ||
| 386 | node.op() == GGML_OP_NORM || | ||
| 387 | node.op() == GGML_OP_RMS_NORM) { | ||
| 388 | ops[0] = node.op(); | ||
| 389 | |||
| 390 | int f = i + 1; | ||
| 391 | while (f < n && f < i + MAX_FUSE) { | ||
| 392 | // conservatively allow fusing only these ops | ||
| 393 | // can be expanded when needed | ||
| 394 | if (gf->nodes[f]->op != GGML_OP_ADD && | ||
| 395 | gf->nodes[f]->op != GGML_OP_MUL && | ||
| 396 | gf->nodes[f]->op != GGML_OP_NORM && | ||
| 397 | gf->nodes[f]->op != GGML_OP_RMS_NORM) { | ||
| 398 | break; | ||
| 399 | } | ||
| 400 | ops[f - i] = gf->nodes[f]->op; | ||
| 401 | f++; | ||
| 402 | } | ||
| 403 | |||
| 404 | f -= i; | ||
| 405 | for (; f > 1; f--) { | ||
| 406 | if (ggml_can_fuse(gf, i, ops, f)) { | ||
| 407 | break; | ||
| 408 | } | ||
| 409 | } | ||
| 410 | |||
| 411 | // add the fused tensors into the node info so we can unfuse them later | ||
| 412 | for (int k = 1; k < f; k++) { | ||
| 413 | ++i; | ||
| 414 | |||
| 415 | // the .dst() becomes the last fused tensor | ||
| 416 | node.add_fused(gf->nodes[i]); | ||
| 417 | } | ||
| 418 | } | ||
| 419 | |||
| 420 | nodes.push_back(std::move(node)); | ||
| 421 | } | ||
| 422 | |||
| 423 | #if 1 | ||
| 424 | // reorder to improve concurrency | ||
| 425 | const auto order = ggml_metal_graph_optimize_reorder(nodes); | ||
| 426 | #else | ||
| 427 | std::vector<int> order(nodes.size()); | ||
| 428 | for (size_t i = 0; i < nodes.size(); i++) { | ||
| 429 | order[i] = i; | ||
| 430 | } | ||
| 431 | #endif | ||
| 432 | |||
| 433 | // unfuse | ||
| 434 | { | ||
| 435 | int j = 0; | ||
| 436 | for (const auto i : order) { | ||
| 437 | const auto & node = nodes[i]; | ||
| 438 | |||
| 439 | gf->nodes[j++] = node.node; | ||
| 440 | |||
| 441 | for (auto * fused : node.fused) { | ||
| 442 | gf->nodes[j++] = fused; | ||
| 443 | } | ||
| 444 | } | ||
| 445 | } | ||
| 446 | } | ||
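
At the core of `ggml_mem_ranges_check()` above is a half-open interval overlap test within a single buffer. The following standalone sketch of that predicate, with illustrative values, is not part of the commit; note that the `>=` comparison also flags ranges that merely touch, which appears to be a deliberately conservative choice:

```cpp
// Sketch of the conflict rule used by ggml_mem_ranges_check() above.
#include <cstdint>
#include <cstdio>

struct range { uint64_t p0, p1; }; // [p0, p1) within the same buffer

static bool ranges_conflict(range a, range b) {
    return a.p0 < b.p1 && a.p1 >= b.p0;
}

int main() {
    std::printf("%d\n", ranges_conflict({0, 64}, {32, 128}));  // 1: genuine overlap
    std::printf("%d\n", ranges_conflict({0, 64}, {64, 128}));  // 1: adjacent, still flagged
    std::printf("%d\n", ranges_conflict({0, 64}, {128, 256})); // 0: disjoint
}
```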
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.h b/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.h
new file mode 100644
index 0000000..3acbc6a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-common.h
| @@ -0,0 +1,52 @@ | |||
| 1 | // helper functions for ggml-metal that are too difficult to implement in Objective-C | ||
| 2 | |||
| 3 | #pragma once | ||
| 4 | |||
| 5 | #include <stdbool.h> | ||
| 6 | |||
| 7 | #ifdef __cplusplus | ||
| 8 | extern "C" { | ||
| 9 | #endif | ||
| 10 | |||
| 11 | struct ggml_tensor; | ||
| 12 | struct ggml_cgraph; | ||
| 13 | |||
| 14 | enum ggml_mem_range_type { | ||
| 15 | MEM_RANGE_TYPE_SRC = 0, | ||
| 16 | MEM_RANGE_TYPE_DST = 1, | ||
| 17 | }; | ||
| 18 | |||
| 19 | // a helper object that can be used for reordering operations to improve concurrency | ||
| 20 | // | ||
| 21 | // the fundamental idea is that a set of tasks (either ggml ops, or something else) can run concurrently if they | ||
| 22 | // don't write to memory that is being read by another task or written to by another task in the set | ||
| 23 | // | ||
| 24 | // with this structure, we can add tasks to the set, setting memory constraints. we can also check if a new task | ||
| 25 | // can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the | ||
| 26 | // tasks already in the set) | ||
| 27 | // | ||
| 28 | typedef struct ggml_mem_ranges * ggml_mem_ranges_t; | ||
| 29 | |||
| 30 | ggml_mem_ranges_t ggml_mem_ranges_init(int debug); | ||
| 31 | void ggml_mem_ranges_free(ggml_mem_ranges_t mrs); | ||
| 32 | |||
| 33 | // remove all ranges from the set | ||
| 34 | void ggml_mem_ranges_reset(ggml_mem_ranges_t mrs); | ||
| 35 | |||
| 36 | // add src or dst ranges to track | ||
| 37 | bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor); | ||
| 38 | |||
| 39 | // return false if: | ||
| 40 | // - new src range overlaps with any existing dst range | ||
| 41 | // - new dst range overlaps with any existing range (src or dst) | ||
| 42 | bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const struct ggml_tensor * tensor); | ||
| 43 | |||
| 44 | // reorder the nodes in the graph to improve concurrency, while respecting fusion | ||
| 45 | // | ||
| 46 | // note: this implementation is generic and not specific to metal | ||
| 47 | // if it proves to work well, we can start using it for other backends in the future | ||
| 48 | void ggml_graph_optimize(struct ggml_cgraph * gf); | ||
| 49 | |||
| 50 | #ifdef __cplusplus | ||
| 51 | } | ||
| 52 | #endif | ||
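
A hedged usage sketch of the API declared above, building one concurrent set of nodes at a time with check/add/reset. The node array is assumed to come from an already-built ggml graph; the helper itself is illustrative and not part of the commit:

```cpp
#include "ggml.h"
#include "ggml-metal-common.h"

// Sketch: walk a node list and reset the range set whenever a node conflicts
// with the current concurrent set (this is where a barrier would be placed).
static void build_concurrent_sets(struct ggml_tensor ** nodes, int n_nodes) {
    ggml_mem_ranges_t mrs = ggml_mem_ranges_init(/*debug =*/ 0);

    for (int i = 0; i < n_nodes; ++i) {
        if (!ggml_mem_ranges_check(mrs, nodes[i])) {
            ggml_mem_ranges_reset(mrs); // conflict -> start a new concurrent set
        }
        ggml_mem_ranges_add(mrs, nodes[i]); // track src/dst ranges of this node
    }

    ggml_mem_ranges_free(mrs);
}
```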
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.h b/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.h
new file mode 100644
index 0000000..abf4b06
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.h
| @@ -0,0 +1,41 @@ | |||
| 1 | #pragma once | ||
| 2 | |||
| 3 | #include "ggml-metal-device.h" | ||
| 4 | |||
| 5 | #ifdef __cplusplus | ||
| 6 | extern "C" { | ||
| 7 | #endif | ||
| 8 | |||
| 9 | // | ||
| 10 | // backend context | ||
| 11 | // | ||
| 12 | |||
| 13 | typedef struct ggml_metal * ggml_metal_t; | ||
| 14 | |||
| 15 | ggml_metal_t ggml_metal_init(ggml_metal_device_t dev); | ||
| 16 | void ggml_metal_free(ggml_metal_t ctx); | ||
| 17 | |||
| 18 | const char * ggml_metal_get_name(ggml_metal_t ctx); | ||
| 19 | |||
| 20 | void ggml_metal_synchronize(ggml_metal_t ctx); | ||
| 21 | |||
| 22 | void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); | ||
| 23 | void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); | ||
| 24 | bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); | ||
| 25 | |||
| 26 | enum ggml_status ggml_metal_graph_compute (ggml_metal_t ctx, struct ggml_cgraph * gf); | ||
| 27 | void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf); | ||
| 28 | |||
| 29 | void ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev); | ||
| 30 | void ggml_metal_event_wait (ggml_metal_t ctx, ggml_metal_event_t ev); | ||
| 31 | |||
| 32 | ggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx); | ||
| 33 | |||
| 34 | void ggml_metal_set_n_cb (ggml_metal_t ctx, int n_cb); | ||
| 35 | void ggml_metal_set_abort_callback (ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data); | ||
| 36 | bool ggml_metal_supports_family (ggml_metal_t ctx, int family); | ||
| 37 | void ggml_metal_capture_next_compute(ggml_metal_t ctx); | ||
| 38 | |||
| 39 | #ifdef __cplusplus | ||
| 40 | } | ||
| 41 | #endif | ||
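
A minimal lifecycle sketch for the backend context declared above, assuming a device handle obtained via ggml-metal-device.h and an already-built compute graph `gf`. Error handling is omitted and the helper is illustrative, not code from this commit:

```cpp
#include "ggml-metal-device.h"
#include "ggml-metal-context.h"

// Sketch: run one graph on the Metal backend context and block until it finishes.
static enum ggml_status run_graph_on_metal(struct ggml_cgraph * gf) {
    ggml_metal_device_t dev = ggml_metal_device_get(0);

    ggml_metal_t ctx = ggml_metal_init(dev);
    ggml_metal_set_n_cb(ctx, 1);        // 1-2 extra command buffers is the recommended range

    ggml_metal_graph_optimize(ctx, gf); // reorder/fuse the graph for concurrency
    enum ggml_status status = ggml_metal_graph_compute(ctx, gf); // submits asynchronously
    ggml_metal_synchronize(ctx);        // wait for the last queued command buffer

    ggml_metal_free(ctx);
    return status;
}
```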
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m b/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m
new file mode 100644
index 0000000..5d3a8ce
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-context.m
| @@ -0,0 +1,702 @@ | |||
| 1 | #import "ggml-metal-context.h" | ||
| 2 | |||
| 3 | #import "ggml-impl.h" | ||
| 4 | #import "ggml-backend-impl.h" | ||
| 5 | |||
| 6 | #import "ggml-metal-impl.h" | ||
| 7 | #import "ggml-metal-common.h" | ||
| 8 | #import "ggml-metal-ops.h" | ||
| 9 | |||
| 10 | #import <Foundation/Foundation.h> | ||
| 11 | |||
| 12 | #import <Metal/Metal.h> | ||
| 13 | |||
| 14 | #undef MIN | ||
| 15 | #undef MAX | ||
| 16 | #define MIN(a, b) ((a) < (b) ? (a) : (b)) | ||
| 17 | #define MAX(a, b) ((a) > (b) ? (a) : (b)) | ||
| 18 | |||
| 19 | // max number of MTLCommandBuffer used to submit a graph for processing | ||
| 20 | #define GGML_METAL_MAX_COMMAND_BUFFERS 8 | ||
| 21 | |||
| 22 | struct ggml_metal_command_buffer { | ||
| 23 | id<MTLCommandBuffer> obj; | ||
| 24 | }; | ||
| 25 | |||
| 26 | struct ggml_metal { | ||
| 27 | char name[128]; | ||
| 28 | |||
| 29 | ggml_metal_device_t dev; | ||
| 30 | ggml_metal_library_t lib; | ||
| 31 | |||
| 32 | ggml_metal_event_t ev_cpy; // for async copies | ||
| 33 | |||
| 34 | dispatch_queue_t d_queue; | ||
| 35 | |||
| 36 | // additional, inference-time compiled pipelines | ||
| 37 | ggml_metal_pipelines_t pipelines_ext; | ||
| 38 | |||
| 39 | bool use_fusion; | ||
| 40 | bool use_concurrency; | ||
| 41 | bool use_graph_optimize; | ||
| 42 | |||
| 43 | int debug_graph; | ||
| 44 | int debug_fusion; | ||
| 45 | |||
| 46 | // how many times a given op was fused | ||
| 47 | uint64_t fuse_cnt[GGML_OP_COUNT]; | ||
| 48 | |||
| 49 | // capture state | ||
| 50 | bool capture_next_compute; | ||
| 51 | bool capture_started; | ||
| 52 | |||
| 53 | id<MTLCaptureScope> capture_scope; | ||
| 54 | |||
| 55 | // command buffer state | ||
| 56 | int n_cb; // number of extra threads used to submit the command buffers | ||
| 57 | int n_nodes_0; // number of nodes submitted by the main thread | ||
| 58 | int n_nodes_1; // remaining number of nodes submitted by the n_cb threads | ||
| 59 | int n_nodes_per_cb; | ||
| 60 | |||
| 61 | struct ggml_cgraph * gf; | ||
| 62 | |||
| 63 | // the callback given to the thread pool | ||
| 64 | void (^encode_async)(size_t ith); | ||
| 65 | |||
| 66 | // n_cb command buffers + 1 used by the main thread | ||
| 67 | struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1]; | ||
| 68 | |||
| 69 | // extra command buffers for things like getting, setting and copying tensors | ||
| 70 | NSMutableArray * cmd_bufs_ext; | ||
| 71 | |||
| 72 | // the last command buffer queued into the Metal queue with operations relevant to the current Metal backend | ||
| 73 | id<MTLCommandBuffer> cmd_buf_last; | ||
| 74 | |||
| 75 | // abort ggml_metal_graph_compute if callback returns true | ||
| 76 | ggml_abort_callback abort_callback; | ||
| 77 | void * abort_callback_data; | ||
| 78 | }; | ||
| 79 | |||
| 80 | ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { | ||
| 81 | GGML_LOG_INFO("%s: allocating\n", __func__); | ||
| 82 | |||
| 83 | #if TARGET_OS_OSX && !GGML_METAL_NDEBUG | ||
| 84 | // Show all the Metal device instances in the system | ||
| 85 | NSArray * devices = MTLCopyAllDevices(); | ||
| 86 | for (id<MTLDevice> device in devices) { | ||
| 87 | GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]); | ||
| 88 | } | ||
| 89 | [devices release]; // since it was created by a *Copy* C method | ||
| 90 | #endif | ||
| 91 | |||
| 92 | // init context | ||
| 93 | ggml_metal_t res = calloc(1, sizeof(struct ggml_metal)); | ||
| 94 | |||
| 95 | id<MTLDevice> device = ggml_metal_device_get_obj(dev); | ||
| 96 | |||
| 97 | GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]); | ||
| 98 | |||
| 99 | // TODO: would it be better to have one queue for the backend and one queue for the device? | ||
| 100 | // the graph encoders and async ops would use the backend queue while the sync ops would use the device queue? | ||
| 101 | //res->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND] | ||
| 102 | id<MTLCommandQueue> queue = ggml_metal_device_get_queue(dev); | ||
| 103 | if (queue == nil) { | ||
| 104 | GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); | ||
| 105 | return NULL; | ||
| 106 | } | ||
| 107 | |||
| 108 | res->dev = dev; | ||
| 109 | res->lib = ggml_metal_device_get_library(dev); | ||
| 110 | if (res->lib == NULL) { | ||
| 111 | GGML_LOG_WARN("%s: the device does not have a precompiled Metal library - this is unexpected\n", __func__); | ||
| 112 | GGML_LOG_WARN("%s: will try to compile it on the fly\n", __func__); | ||
| 113 | |||
| 114 | res->lib = ggml_metal_library_init(dev); | ||
| 115 | if (res->lib == NULL) { | ||
| 116 | GGML_LOG_ERROR("%s: error: failed to initialize the Metal library\n", __func__); | ||
| 117 | |||
| 118 | free(res); | ||
| 119 | |||
| 120 | return NULL; | ||
| 121 | } | ||
| 122 | } | ||
| 123 | |||
| 124 | res->ev_cpy = ggml_metal_device_event_init(dev); | ||
| 125 | |||
| 126 | const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); | ||
| 127 | |||
| 128 | snprintf(res->name, sizeof(res->name), "%s", props_dev->name); | ||
| 129 | |||
| 130 | res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); | ||
| 131 | |||
| 132 | res->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil; | ||
| 133 | res->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil; | ||
| 134 | |||
| 135 | { | ||
| 136 | const char * val = getenv("GGML_METAL_GRAPH_DEBUG"); | ||
| 137 | res->debug_graph = val ? atoi(val) : 0; | ||
| 138 | } | ||
| 139 | |||
| 140 | { | ||
| 141 | const char * val = getenv("GGML_METAL_FUSION_DEBUG"); | ||
| 142 | res->debug_fusion = val ? atoi(val) : 0; | ||
| 143 | } | ||
| 144 | |||
| 145 | res->use_graph_optimize = true; | ||
| 146 | |||
| 147 | if (getenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE") != NULL) { | ||
| 148 | res->use_graph_optimize = false; | ||
| 149 | } | ||
| 150 | |||
| 151 | memset(res->fuse_cnt, 0, sizeof(res->fuse_cnt)); | ||
| 152 | |||
| 153 | GGML_LOG_INFO("%s: use fusion = %s\n", __func__, res->use_fusion ? "true" : "false"); | ||
| 154 | GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, res->use_concurrency ? "true" : "false"); | ||
| 155 | GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false"); | ||
| 156 | |||
| 157 | res->capture_next_compute = false; | ||
| 158 | res->capture_started = false; | ||
| 159 | res->capture_scope = nil; | ||
| 160 | |||
| 161 | res->gf = nil; | ||
| 162 | res->encode_async = nil; | ||
| 163 | for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { | ||
| 164 | res->cmd_bufs[i].obj = nil; | ||
| 165 | } | ||
| 166 | |||
| 167 | res->cmd_bufs_ext = [[NSMutableArray alloc] init]; | ||
| 168 | |||
| 169 | res->cmd_buf_last = nil; | ||
| 170 | |||
| 171 | res->pipelines_ext = ggml_metal_pipelines_init(); | ||
| 172 | |||
| 173 | return res; | ||
| 174 | } | ||
| 175 | |||
| 176 | void ggml_metal_free(ggml_metal_t ctx) { | ||
| 177 | GGML_LOG_INFO("%s: deallocating\n", __func__); | ||
| 178 | |||
| 179 | for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { | ||
| 180 | if (ctx->cmd_bufs[i].obj) { | ||
| 181 | [ctx->cmd_bufs[i].obj release]; | ||
| 182 | } | ||
| 183 | } | ||
| 184 | |||
| 185 | for (int i = 0; i < (int) ctx->cmd_bufs_ext.count; ++i) { | ||
| 186 | if (ctx->cmd_bufs_ext[i]) { | ||
| 187 | [ctx->cmd_bufs_ext[i] release]; | ||
| 188 | } | ||
| 189 | } | ||
| 190 | |||
| 191 | [ctx->cmd_bufs_ext removeAllObjects]; | ||
| 192 | [ctx->cmd_bufs_ext release]; | ||
| 193 | |||
| 194 | if (ctx->pipelines_ext) { | ||
| 195 | ggml_metal_pipelines_free(ctx->pipelines_ext); | ||
| 196 | ctx->pipelines_ext = nil; | ||
| 197 | } | ||
| 198 | |||
| 199 | if (ctx->debug_fusion > 0) { | ||
| 200 | GGML_LOG_DEBUG("%s: fusion stats:\n", __func__); | ||
| 201 | for (int i = 0; i < GGML_OP_COUNT; i++) { | ||
| 202 | if (ctx->fuse_cnt[i] == 0) { | ||
| 203 | continue; | ||
| 204 | } | ||
| 205 | |||
| 206 | // note: cannot use ggml_log here | ||
| 207 | GGML_LOG_DEBUG("%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]); | ||
| 208 | } | ||
| 209 | } | ||
| 210 | |||
| 211 | Block_release(ctx->encode_async); | ||
| 212 | |||
| 213 | //[ctx->queue release]; // [TAG_QUEUE_PER_BACKEND] | ||
| 214 | |||
| 215 | dispatch_release(ctx->d_queue); | ||
| 216 | |||
| 217 | ggml_metal_device_event_free(ctx->dev, ctx->ev_cpy); | ||
| 218 | |||
| 219 | free(ctx); | ||
| 220 | } | ||
| 221 | |||
| 222 | const char * ggml_metal_get_name(ggml_metal_t ctx) { | ||
| 223 | return ctx->name; | ||
| 224 | } | ||
| 225 | |||
| 226 | void ggml_metal_synchronize(ggml_metal_t ctx) { | ||
| 227 | // wait for any backend operations to finish | ||
| 228 | if (ctx->cmd_buf_last) { | ||
| 229 | [ctx->cmd_buf_last waitUntilCompleted]; | ||
| 230 | ctx->cmd_buf_last = nil; | ||
| 231 | } | ||
| 232 | |||
| 233 | // check status of all command buffers | ||
| 234 | { | ||
| 235 | const int n_cb = ctx->n_cb; | ||
| 236 | |||
| 237 | for (int cb_idx = 0; cb_idx <= n_cb; ++cb_idx) { | ||
| 238 | id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj; | ||
| 239 | if (!cmd_buf) { | ||
| 240 | continue; | ||
| 241 | } | ||
| 242 | |||
| 243 | MTLCommandBufferStatus status = [cmd_buf status]; | ||
| 244 | if (status != MTLCommandBufferStatusCompleted) { | ||
| 245 | GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, cb_idx, (int) status); | ||
| 246 | if (status == MTLCommandBufferStatusError) { | ||
| 247 | GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); | ||
| 248 | } | ||
| 249 | GGML_ABORT("fatal error"); | ||
| 250 | } | ||
| 251 | } | ||
| 252 | } | ||
| 253 | |||
| 254 | // release any completed extra command buffers | ||
| 255 | if (ctx->cmd_bufs_ext.count > 0) { | ||
| 256 | for (size_t i = 0; i < ctx->cmd_bufs_ext.count; ++i) { | ||
| 257 | id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_ext[i]; | ||
| 258 | |||
| 259 | MTLCommandBufferStatus status = [cmd_buf status]; | ||
| 260 | if (status != MTLCommandBufferStatusCompleted) { | ||
| 261 | GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, (int) i, (int) status); | ||
| 262 | if (status == MTLCommandBufferStatusError) { | ||
| 263 | GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); | ||
| 264 | } | ||
| 265 | GGML_ABORT("fatal error"); | ||
| 266 | } | ||
| 267 | |||
| 268 | [cmd_buf release]; | ||
| 269 | } | ||
| 270 | |||
| 271 | [ctx->cmd_bufs_ext removeAllObjects]; | ||
| 272 | } | ||
| 273 | } | ||
| 274 | |||
| 275 | static struct ggml_metal_buffer_id ggml_metal_get_buffer_id(const struct ggml_tensor * t) { | ||
| 276 | if (!t) { | ||
| 277 | return (struct ggml_metal_buffer_id) { nil, 0 }; | ||
| 278 | } | ||
| 279 | |||
| 280 | ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer; | ||
| 281 | |||
| 282 | return ggml_metal_buffer_get_id(buffer->context, t); | ||
| 283 | } | ||
| 284 | |||
| 285 | void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { | ||
| 286 | @autoreleasepool { | ||
| 287 | // wrap the source data into a Metal buffer | ||
| 288 | id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev); | ||
| 289 | id<MTLBuffer> buf_src = [device newBufferWithBytes:data | ||
| 290 | length:size | ||
| 291 | options:MTLResourceStorageModeShared]; | ||
| 292 | |||
| 293 | GGML_ASSERT(buf_src); | ||
| 294 | |||
| 295 | struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(tensor); | ||
| 296 | if (bid_dst.metal == nil) { | ||
| 297 | GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name); | ||
| 298 | } | ||
| 299 | |||
| 300 | bid_dst.offs += offset; | ||
| 301 | |||
| 302 | // queue the copy operation into the queue of the Metal context | ||
| 303 | // this will be queued at the end, after any currently ongoing GPU operations | ||
| 304 | id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev); | ||
| 305 | id<MTLCommandBuffer> cmd_buf = [queue commandBuffer]; | ||
| 306 | id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder]; | ||
| 307 | |||
| 308 | [encoder copyFromBuffer:buf_src | ||
| 309 | sourceOffset:0 | ||
| 310 | toBuffer:bid_dst.metal | ||
| 311 | destinationOffset:bid_dst.offs | ||
| 312 | size:size]; | ||
| 313 | |||
| 314 | [encoder endEncoding]; | ||
| 315 | [cmd_buf commit]; | ||
| 316 | [buf_src release]; | ||
| 317 | |||
| 318 | // do not wait here for completion | ||
| 319 | //[cmd_buf waitUntilCompleted]; | ||
| 320 | |||
| 321 | // instead, remember a reference to the command buffer and wait for it later if needed | ||
| 322 | [ctx->cmd_bufs_ext addObject:cmd_buf]; | ||
| 323 | ctx->cmd_buf_last = cmd_buf; | ||
| 324 | |||
| 325 | [cmd_buf retain]; | ||
| 326 | } | ||
| 327 | } | ||
| 328 | |||
| 329 | void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { | ||
| 330 | @autoreleasepool { | ||
| 331 | id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev); | ||
| 332 | id<MTLBuffer> buf_dst = [device newBufferWithBytesNoCopy:data | ||
| 333 | length:size | ||
| 334 | options:MTLResourceStorageModeShared | ||
| 335 | deallocator:nil]; | ||
| 336 | |||
| 337 | GGML_ASSERT(buf_dst); | ||
| 338 | |||
| 339 | struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(tensor); | ||
| 340 | if (bid_src.metal == nil) { | ||
| 341 | GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name); | ||
| 342 | } | ||
| 343 | |||
| 344 | bid_src.offs += offset; | ||
| 345 | |||
| 346 | // queue the copy operation into the queue of the Metal context | ||
| 347 | // this will be queued at the end, after any currently ongoing GPU operations | ||
| 348 | id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev); | ||
| 349 | id<MTLCommandBuffer> cmd_buf = [queue commandBuffer]; | ||
| 350 | id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder]; | ||
| 351 | |||
| 352 | [encoder copyFromBuffer:bid_src.metal | ||
| 353 | sourceOffset:bid_src.offs | ||
| 354 | toBuffer:buf_dst | ||
| 355 | destinationOffset:0 | ||
| 356 | size:size]; | ||
| 357 | |||
| 358 | [encoder endEncoding]; | ||
| 359 | [cmd_buf commit]; | ||
| 360 | [buf_dst release]; | ||
| 361 | |||
| 362 | // do not wait here for completion | ||
| 363 | //[cmd_buf waitUntilCompleted]; | ||
| 364 | |||
| 365 | // instead, remember a reference to the command buffer and wait for it later if needed | ||
| 366 | [ctx->cmd_bufs_ext addObject:cmd_buf]; | ||
| 367 | ctx->cmd_buf_last = cmd_buf; | ||
| 368 | |||
| 369 | [cmd_buf retain]; | ||
| 370 | } | ||
| 371 | } | ||
| 372 | |||
| 373 | bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) { | ||
| 374 | @autoreleasepool { | ||
| 375 | struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(src); | ||
| 376 | struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(dst); | ||
| 377 | |||
| 378 | if (bid_src.metal == nil || bid_dst.metal == nil) { | ||
| 379 | return false; | ||
| 380 | } | ||
| 381 | |||
| 382 | // queue the copy operation into the Metal context | ||
| 383 | // this will be queued at the end, after any currently ongoing GPU operations | ||
| 384 | id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx_src->dev); | ||
| 385 | id<MTLCommandBuffer> cmd_buf = [queue commandBuffer]; | ||
| 386 | id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder]; | ||
| 387 | |||
| 388 | [encoder copyFromBuffer:bid_src.metal | ||
| 389 | sourceOffset:bid_src.offs | ||
| 390 | toBuffer:bid_dst.metal | ||
| 391 | destinationOffset:bid_dst.offs | ||
| 392 | size:ggml_nbytes(src)]; | ||
| 393 | |||
| 394 | [encoder endEncoding]; | ||
| 395 | |||
| 396 | ggml_metal_event_t ev_cpy = ggml_metal_get_ev_cpy(ctx_src); | ||
| 397 | ggml_metal_event_encode_signal(ev_cpy, cmd_buf); | ||
| 398 | |||
| 399 | [cmd_buf commit]; | ||
| 400 | |||
| 401 | // do not wait here for completion | ||
| 402 | //[cmd_buf waitUntilCompleted]; | ||
| 403 | |||
| 404 | // instead, remember a reference to the command buffer and wait for it later if needed | ||
| 405 | [ctx_src->cmd_bufs_ext addObject:cmd_buf]; | ||
| 406 | ctx_src->cmd_buf_last = cmd_buf; | ||
| 407 | |||
| 408 | [cmd_buf retain]; | ||
| 409 | |||
| 410 | ggml_metal_event_wait(ctx_dst, ev_cpy); | ||
| 411 | |||
| 412 | return true; | ||
| 413 | } | ||
| 414 | } | ||
| 415 | |||
| 416 | enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) { | ||
| 417 | // number of nodes encoded by the main thread (empirically determined) | ||
| 418 | const int n_main = MAX(64, 0.1*gf->n_nodes); | ||
| 419 | |||
| 420 | // number of threads in addition to the main thread | ||
| 421 | const int n_cb = ctx->n_cb; | ||
| 422 | |||
| 423 | // keep the memory wired | ||
| 424 | ggml_metal_device_rsets_keep_alive(ctx->dev); | ||
| 425 | |||
| 426 | // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them | ||
| 427 | // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread | ||
| 428 | // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes | ||
| 429 | // each thread creates its own command buffer and enqueues the ops in parallel | ||
| 430 | // | ||
| 431 | // tests on M1 Pro and M2 Ultra using LLaMA models show that optimal values for n_cb are 1 or 2 | ||
| 432 | |||
| 433 | @autoreleasepool { | ||
| 434 | ctx->gf = gf; | ||
| 435 | |||
| 436 | ctx->n_nodes_0 = MIN(n_main, gf->n_nodes); | ||
| 437 | ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0; | ||
| 438 | |||
| 439 | ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb; | ||
| 440 | |||
| 441 | const bool use_capture = ctx->capture_next_compute; | ||
| 442 | if (use_capture) { | ||
| 443 | ctx->capture_next_compute = false; | ||
| 444 | |||
| 445 | // make sure all previous computations have finished before starting the capture | ||
| 446 | if (ctx->cmd_buf_last) { | ||
| 447 | [ctx->cmd_buf_last waitUntilCompleted]; | ||
| 448 | ctx->cmd_buf_last = nil; | ||
| 449 | } | ||
| 450 | |||
| 451 | if (!ctx->capture_started) { | ||
| 452 | // create capture scope | ||
| 453 | id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev); | ||
| 454 | ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:device]; | ||
| 455 | |||
| 456 | MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; | ||
| 457 | descriptor.captureObject = ctx->capture_scope; | ||
| 458 | descriptor.destination = MTLCaptureDestinationGPUTraceDocument; | ||
| 459 | descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]]; | ||
| 460 | |||
| 461 | NSError * error = nil; | ||
| 462 | if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { | ||
| 463 | GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]); | ||
| 464 | } else { | ||
| 465 | [ctx->capture_scope beginScope]; | ||
| 466 | ctx->capture_started = true; | ||
| 467 | } | ||
| 468 | } | ||
| 469 | } | ||
| 470 | |||
| 471 | // short-hand | ||
| 472 | id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev); | ||
| 473 | |||
| 474 | // the main thread commits the first few commands immediately | ||
| 475 | // cmd_buf[n_cb] | ||
| 476 | { | ||
| 477 | id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences]; | ||
| 478 | [cmd_buf retain]; | ||
| 479 | |||
| 480 | if (ctx->cmd_bufs[n_cb].obj) { | ||
| 481 | [ctx->cmd_bufs[n_cb].obj release]; | ||
| 482 | } | ||
| 483 | ctx->cmd_bufs[n_cb].obj = cmd_buf; | ||
| 484 | |||
| 485 | [cmd_buf enqueue]; | ||
| 486 | |||
| 487 | ctx->encode_async(n_cb); | ||
| 488 | } | ||
| 489 | |||
| 490 | // remember the command buffer for the next iteration | ||
| 491 | ctx->cmd_buf_last = ctx->cmd_bufs[n_cb].obj; | ||
| 492 | |||
| 493 | // prepare the rest of the command buffers asynchronously (optional) | ||
| 494 | // cmd_buf[0.. n_cb) | ||
| 495 | for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { | ||
| 496 | id<MTLCommandBuffer> cmd_buf = [queue commandBufferWithUnretainedReferences]; | ||
| 497 | [cmd_buf retain]; | ||
| 498 | |||
| 499 | if (ctx->cmd_bufs[cb_idx].obj) { | ||
| 500 | [ctx->cmd_bufs[cb_idx].obj release]; | ||
| 501 | } | ||
| 502 | ctx->cmd_bufs[cb_idx].obj = cmd_buf; | ||
| 503 | |||
| 504 | // always enqueue the first two command buffers | ||
| 505 | // enqueue all of the command buffers if we don't need to abort | ||
| 506 | if (cb_idx < 2 || ctx->abort_callback == NULL) { | ||
| 507 | [cmd_buf enqueue]; | ||
| 508 | |||
| 509 | // update the pointer to the last queued command buffer | ||
| 510 | // this is needed to implement synchronize() | ||
| 511 | ctx->cmd_buf_last = cmd_buf; | ||
| 512 | } | ||
| 513 | } | ||
| 514 | |||
| 515 | dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async); | ||
| 516 | |||
| 517 | // for debugging: block until graph is computed | ||
| 518 | //[ctx->cmd_buf_last waitUntilCompleted]; | ||
| 519 | |||
| 520 | // enter here only when capturing in order to wait for all computation to finish | ||
| 521 | // otherwise, we leave the graph to compute asynchronously | ||
| 522 | if (!use_capture && ctx->capture_started) { | ||
| 523 | // wait for completion and check status of each command buffer | ||
| 524 | // needed to detect if the device ran out-of-memory for example (#1881) | ||
| 525 | { | ||
| 526 | id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj; | ||
| 527 | [cmd_buf waitUntilCompleted]; | ||
| 528 | |||
| 529 | MTLCommandBufferStatus status = [cmd_buf status]; | ||
| 530 | if (status != MTLCommandBufferStatusCompleted) { | ||
| 531 | GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status); | ||
| 532 | if (status == MTLCommandBufferStatusError) { | ||
| 533 | GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); | ||
| 534 | } | ||
| 535 | |||
| 536 | return GGML_STATUS_FAILED; | ||
| 537 | } | ||
| 538 | } | ||
| 539 | |||
| 540 | for (int i = 0; i < n_cb; ++i) { | ||
| 541 | id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj; | ||
| 542 | [cmd_buf waitUntilCompleted]; | ||
| 543 | |||
| 544 | MTLCommandBufferStatus status = [cmd_buf status]; | ||
| 545 | if (status != MTLCommandBufferStatusCompleted) { | ||
| 546 | GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); | ||
| 547 | if (status == MTLCommandBufferStatusError) { | ||
| 548 | GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); | ||
| 549 | } | ||
| 550 | |||
| 551 | return GGML_STATUS_FAILED; | ||
| 552 | } | ||
| 553 | |||
| 554 | id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil); | ||
| 555 | if (!next_buffer) { | ||
| 556 | continue; | ||
| 557 | } | ||
| 558 | |||
| 559 | const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); | ||
| 560 | if (next_queued) { | ||
| 561 | continue; | ||
| 562 | } | ||
| 563 | |||
| 564 | if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { | ||
| 565 | GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i); | ||
| 566 | return GGML_STATUS_ABORTED; | ||
| 567 | } | ||
| 568 | |||
| 569 | [next_buffer commit]; | ||
| 570 | } | ||
| 571 | |||
| 572 | [ctx->capture_scope endScope]; | ||
| 573 | [[MTLCaptureManager sharedCaptureManager] stopCapture]; | ||
| 574 | } | ||
| 575 | } | ||
| 576 | |||
| 577 | return GGML_STATUS_SUCCESS; | ||
| 578 | } | ||
| 579 | |||
| 580 | void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf) { | ||
| 581 | //const int64_t t_start = ggml_time_us(); | ||
| 582 | |||
| 583 | if (ctx->use_graph_optimize) { | ||
| 584 | ggml_graph_optimize(gf); | ||
| 585 | } | ||
| 586 | |||
| 587 | //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0); | ||
| 588 | } | ||
| 589 | |||
| 590 | void ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev) { | ||
| 591 | @autoreleasepool { | ||
| 592 | id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev); | ||
| 593 | id<MTLCommandBuffer> cmd_buf = [queue commandBuffer]; | ||
| 594 | |||
| 595 | ggml_metal_event_encode_signal(ev, cmd_buf); | ||
| 596 | |||
| 597 | [cmd_buf commit]; | ||
| 598 | |||
| 599 | [ctx->cmd_bufs_ext addObject:cmd_buf]; | ||
| 600 | ctx->cmd_buf_last = cmd_buf; | ||
| 601 | |||
| 602 | [cmd_buf retain]; | ||
| 603 | } | ||
| 604 | } | ||
| 605 | |||
| 606 | void ggml_metal_event_wait(ggml_metal_t ctx, ggml_metal_event_t ev) { | ||
| 607 | @autoreleasepool { | ||
| 608 | id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev); | ||
| 609 | id<MTLCommandBuffer> cmd_buf = [queue commandBuffer]; | ||
| 610 | |||
| 611 | ggml_metal_event_encode_wait(ev, cmd_buf); | ||
| 612 | |||
| 613 | [cmd_buf commit]; | ||
| 614 | |||
| 615 | [ctx->cmd_bufs_ext addObject:cmd_buf]; | ||
| 616 | ctx->cmd_buf_last = cmd_buf; | ||
| 617 | |||
| 618 | [cmd_buf retain]; | ||
| 619 | } | ||
| 620 | } | ||
| 621 | |||
| 622 | ggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx) { | ||
| 623 | return ctx->ev_cpy; | ||
| 624 | } | ||
| 625 | |||
| 626 | void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) { | ||
| 627 | if (ctx->n_cb != n_cb) { | ||
| 628 | ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS); | ||
| 629 | |||
| 630 | if (ctx->n_cb > 2) { | ||
| 631 | GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb); | ||
| 632 | } | ||
| 633 | } | ||
| 634 | |||
| 635 | if (ctx->encode_async) { | ||
| 636 | Block_release(ctx->encode_async); | ||
| 637 | } | ||
| 638 | |||
| 639 | ctx->encode_async = Block_copy(^(size_t iter) { | ||
| 640 | const int cb_idx = iter; | ||
| 641 | const int n_cb_l = ctx->n_cb; | ||
| 642 | |||
| 643 | const int n_nodes_0 = ctx->n_nodes_0; | ||
| 644 | const int n_nodes_1 = ctx->n_nodes_1; | ||
| 645 | |||
| 646 | const int n_nodes_per_cb = ctx->n_nodes_per_cb; | ||
| 647 | |||
| 648 | int idx_start = 0; | ||
| 649 | int idx_end = n_nodes_0; | ||
| 650 | |||
| 651 | if (cb_idx < n_cb_l) { | ||
| 652 | idx_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb); | ||
| 653 | idx_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1)); | ||
| 654 | } | ||
| 655 | |||
| 656 | id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj; | ||
| 657 | |||
| 658 | ggml_metal_op_t ctx_op = ggml_metal_op_init( | ||
| 659 | ctx->dev, | ||
| 660 | cmd_buf, | ||
| 661 | ctx->gf, | ||
| 662 | idx_start, | ||
| 663 | idx_end, | ||
| 664 | ctx->use_fusion, | ||
| 665 | ctx->use_concurrency, | ||
| 666 | ctx->capture_next_compute, | ||
| 667 | ctx->debug_graph, | ||
| 668 | ctx->debug_fusion); | ||
| 669 | |||
| 670 | for (int idx = 0; idx < ggml_metal_op_n_nodes(ctx_op); ++idx) { | ||
| 671 | const int res = ggml_metal_op_encode(ctx_op, idx); | ||
| 672 | if (res == 0) { | ||
| 673 | break; | ||
| 674 | } | ||
| 675 | |||
| 676 | idx += res - 1; | ||
| 677 | } | ||
| 678 | |||
| 679 | ggml_metal_op_free(ctx_op); | ||
| 680 | |||
| 681 | if (cb_idx < 2 || ctx->abort_callback == NULL) { | ||
| 682 | [cmd_buf commit]; | ||
| 683 | } | ||
| 684 | }); | ||
| 685 | } | ||
| 686 | |||
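
The `encode_async` block above partitions the graph between command buffers: buffers with index below `n_cb` each encode a `n_nodes_per_cb`-sized slice of the trailing `n_nodes_1` nodes (the last slice is clamped to `n_nodes_1`), while the extra buffer with index `n_cb` encodes the leading `n_nodes_0` nodes. A standalone sketch of that index arithmetic, with made-up sizes and an assumed rounding rule for `n_nodes_per_cb` (neither value is set in this hunk):

```cpp
// Sketch of the per-command-buffer node ranges computed inside encode_async.
// The sizes are illustrative; n_nodes_per_cb is assumed to be the rounded-up split.
#include <algorithm>
#include <cstdio>

int main() {
    const int n_cb           = 2;                             // buffers for the tail of the graph
    const int n_nodes_0      = 3;                             // leading nodes, handled by buffer n_cb
    const int n_nodes_1      = 10;                            // remaining nodes, split across buffers 0..n_cb-1
    const int n_nodes_per_cb = (n_nodes_1 + n_cb - 1) / n_cb; // = 5 (assumption)

    for (int cb_idx = 0; cb_idx <= n_cb; ++cb_idx) {
        int idx_start = 0;
        int idx_end   = n_nodes_0;

        if (cb_idx < n_cb) {
            idx_start = n_nodes_0 + cb_idx * n_nodes_per_cb;
            idx_end   = n_nodes_0 + std::min(cb_idx == n_cb - 1 ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1);
        }

        printf("cb %d encodes nodes [%d, %d)\n", cb_idx, idx_start, idx_end);
    }
    // prints: cb 0 -> [3, 8), cb 1 -> [8, 13), cb 2 -> [0, 3)

    return 0;
}
```

Together the three ranges cover all `n_nodes_0 + n_nodes_1` nodes exactly once.
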
| 687 | void ggml_metal_set_abort_callback(ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data) { | ||
| 688 | ctx->abort_callback = abort_callback; | ||
| 689 | ctx->abort_callback_data = user_data; | ||
| 690 | } | ||
| 691 | |||
| 692 | bool ggml_metal_supports_family(ggml_metal_t ctx, int family) { | ||
| 693 | GGML_ASSERT(ctx->dev != nil); | ||
| 694 | |||
| 695 | id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev); | ||
| 696 | |||
| 697 | return [device supportsFamily:(MTLGPUFamilyApple1 + family - 1)]; | ||
| 698 | } | ||
| 699 | |||
| 700 | void ggml_metal_capture_next_compute(ggml_metal_t ctx) { | ||
| 701 | ctx->capture_next_compute = true; | ||
| 702 | } | ||
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.cpp b/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.cpp new file mode 100644 index 0000000..517559d --- /dev/null +++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.cpp | |||
| @@ -0,0 +1,1875 @@ | |||
| 1 | #include "ggml-metal-device.h" | ||
| 2 | |||
| 3 | #include "ggml-metal-impl.h" | ||
| 4 | |||
| 5 | #include "ggml-impl.h" | ||
| 6 | |||
| 7 | #include <cassert> | ||
| 8 | #include <memory> | ||
| 9 | #include <string> | ||
| 10 | #include <unordered_map> | ||
| 11 | |||
| 12 | struct ggml_metal_device_deleter { | ||
| 13 | void operator()(ggml_metal_device_t ctx) { | ||
| 14 | ggml_metal_device_free(ctx); | ||
| 15 | } | ||
| 16 | }; | ||
| 17 | |||
| 18 | typedef std::unique_ptr<ggml_metal_device, ggml_metal_device_deleter> ggml_metal_device_ptr; | ||
| 19 | |||
| 20 | ggml_metal_device_t ggml_metal_device_get(int device) { | ||
| 21 | static std::vector<ggml_metal_device_ptr> devs; | ||
| 22 | |||
| 23 | devs.emplace_back(ggml_metal_device_init(device)); | ||
| 24 | |||
| 25 | return devs.back().get(); | ||
| 26 | } | ||
| 27 | |||
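
The `ggml_metal_device_deleter`/`unique_ptr` pairing above is the usual way to give a C-style handle RAII semantics: ownership lives in the smart pointer and destruction is routed through the matching `*_free` call. A minimal sketch of the same pattern for a generic handle (the `my_handle_*` names are illustrative, not from this file):

```cpp
// RAII wrapper for a C-style handle via a custom unique_ptr deleter (illustrative names).
#include <memory>

struct my_handle;                           // opaque handle type
my_handle * my_handle_init(void);           // C-style constructor (assumed to exist)
void        my_handle_free(my_handle * h);  // C-style destructor  (assumed to exist)

struct my_handle_deleter {
    void operator()(my_handle * h) const {
        my_handle_free(h);                  // destruction always goes through the matching free()
    }
};

using my_handle_ptr = std::unique_ptr<my_handle, my_handle_deleter>;

// usage: my_handle_ptr p(my_handle_init());  -- freed automatically when p goes out of scope
```
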
| 28 | struct ggml_metal_pipelines { | ||
| 29 | std::unordered_map<std::string, ggml_metal_pipeline_t> data; | ||
| 30 | }; | ||
| 31 | |||
| 32 | ggml_metal_pipelines_t ggml_metal_pipelines_init(void) { | ||
| 33 | ggml_metal_pipelines_t res = new ggml_metal_pipelines(); | ||
| 34 | |||
| 35 | return res; | ||
| 36 | } | ||
| 37 | |||
| 38 | void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) { | ||
| 39 | if (!ppls) { | ||
| 40 | return; | ||
| 41 | } | ||
| 42 | |||
| 43 | for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) { | ||
| 44 | ggml_metal_pipeline_free(it->second); | ||
| 45 | } | ||
| 46 | |||
| 47 | delete ppls; | ||
| 48 | } | ||
| 49 | |||
| 50 | void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline) { | ||
| 51 | ppls->data[name] = pipeline; | ||
| 52 | } | ||
| 53 | |||
| 54 | ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name) { | ||
| 55 | if (ppls->data.find(name) == ppls->data.end()) { | ||
| 56 | return nullptr; | ||
| 57 | } | ||
| 58 | |||
| 59 | return ppls->data[name]; | ||
| 60 | } | ||
| 61 | |||
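
Every `ggml_metal_library_get_pipeline_*` getter that follows uses the same lookup-or-compile shape: `base` holds the Metal kernel function name, `name` holds the cache key (often with specialization parameters appended), and compilation only happens on a cache miss, optionally with function-constant values. A compressed sketch of that shape using the helpers from this file (the `kernel_example` name is hypothetical):

```cpp
// Shared lookup-or-compile pattern of the pipeline getters (relies on this file's includes).
static ggml_metal_pipeline_with_params get_pipeline_example(ggml_metal_library_t lib, ggml_type tsrc) {
    char base[256];
    char name[256];

    snprintf(base, sizeof(base), "kernel_example_%s", ggml_type_name(tsrc)); // Metal function name
    snprintf(name, sizeof(name), "%s", base);                                // cache key

    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
    if (!res.pipeline) {
        // cache miss: compile (a ggml_metal_cv_t with function constants can be passed instead of nullptr)
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}
```
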
| 62 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) { | ||
| 63 | char base[256]; | ||
| 64 | char name[256]; | ||
| 65 | |||
| 66 | const char * op_str = "undefined"; | ||
| 67 | switch (op) { | ||
| 68 | case GGML_OP_ADD_ID: op_str = "add_id"; break; | ||
| 69 | case GGML_OP_CONCAT: op_str = "concat"; break; | ||
| 70 | default: GGML_ABORT("fatal error"); | ||
| 71 | }; | ||
| 72 | |||
| 73 | snprintf(base, 256, "kernel_%s", op_str); | ||
| 74 | snprintf(name, 256, "%s", base); | ||
| 75 | |||
| 76 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 77 | if (!res.pipeline) { | ||
| 78 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 79 | } | ||
| 80 | |||
| 81 | return res; | ||
| 82 | } | ||
| 83 | |||
| 84 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) { | ||
| 85 | char base[256]; | ||
| 86 | char name[256]; | ||
| 87 | |||
| 88 | snprintf(base, 256, "kernel_cpy_%s_%s", ggml_type_name(tsrc), ggml_type_name(tdst)); | ||
| 89 | snprintf(name, 256, "%s", base); | ||
| 90 | |||
| 91 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 92 | if (!res.pipeline) { | ||
| 93 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 94 | } | ||
| 95 | |||
| 96 | return res; | ||
| 97 | } | ||
| 98 | |||
| 99 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) { | ||
| 100 | GGML_ASSERT(ggml_is_contiguous(op->src[0])); | ||
| 101 | GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type); | ||
| 102 | |||
| 103 | const char * pool_str = "undefined"; | ||
| 104 | switch (op_pool) { | ||
| 105 | case GGML_OP_POOL_AVG: pool_str = "avg"; break; | ||
| 106 | case GGML_OP_POOL_MAX: pool_str = "max"; break; | ||
| 107 | default: GGML_ASSERT(false && "not implemented"); | ||
| 108 | }; | ||
| 109 | |||
| 110 | char base[256]; | ||
| 111 | char name[256]; | ||
| 112 | |||
| 113 | snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type)); | ||
| 114 | snprintf(name, sizeof(name), "%s", base); | ||
| 115 | |||
| 116 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 117 | if (!res.pipeline) { | ||
| 118 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 119 | } | ||
| 120 | |||
| 121 | return res; | ||
| 122 | } | ||
| 123 | |||
| 124 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) { | ||
| 125 | GGML_ASSERT(ggml_is_contiguous(op->src[0])); | ||
| 126 | GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type); | ||
| 127 | |||
| 128 | const char * pool_str = "undefined"; | ||
| 129 | switch (op_pool) { | ||
| 130 | case GGML_OP_POOL_AVG: pool_str = "avg"; break; | ||
| 131 | case GGML_OP_POOL_MAX: pool_str = "max"; break; | ||
| 132 | default: GGML_ASSERT(false && "not implemented"); | ||
| 133 | }; | ||
| 134 | |||
| 135 | char base[256]; | ||
| 136 | char name[256]; | ||
| 137 | |||
| 138 | snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, ggml_type_name(op->src[0]->type)); | ||
| 139 | snprintf(name, 256, "%s", base); | ||
| 140 | |||
| 141 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 142 | if (!res.pipeline) { | ||
| 143 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 144 | } | ||
| 145 | |||
| 146 | return res; | ||
| 147 | } | ||
| 148 | |||
| 149 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) { | ||
| 150 | char base[256]; | ||
| 151 | char name[256]; | ||
| 152 | |||
| 153 | snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc)); | ||
| 154 | snprintf(name, 256, "%s", base); | ||
| 155 | |||
| 156 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 157 | if (!res.pipeline) { | ||
| 158 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 159 | } | ||
| 160 | |||
| 161 | return res; | ||
| 162 | } | ||
| 163 | |||
| 164 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) { | ||
| 165 | char base[256]; | ||
| 166 | char name[256]; | ||
| 167 | |||
| 168 | snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx)); | ||
| 169 | snprintf(name, 256, "%s", base); | ||
| 170 | |||
| 171 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 172 | if (!res.pipeline) { | ||
| 173 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 174 | } | ||
| 175 | |||
| 176 | return res; | ||
| 177 | } | ||
| 178 | |||
| 179 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 180 | char base[256]; | ||
| 181 | char name[256]; | ||
| 182 | |||
| 183 | const int n = op->src[0]->ne[0]; | ||
| 184 | |||
| 185 | snprintf(base, 256, "kernel_diag_%s", ggml_type_name(op->src[0]->type)); | ||
| 186 | snprintf(name, 256, "%s_n=%d", base, n); | ||
| 187 | |||
| 188 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 189 | if (!res.pipeline) { | ||
| 190 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 191 | } | ||
| 192 | |||
| 193 | res.nsg = 1; | ||
| 194 | res.smem = 0; | ||
| 195 | |||
| 196 | return res; | ||
| 197 | } | ||
| 198 | |||
| 199 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) { | ||
| 200 | char base[256]; | ||
| 201 | char name[256]; | ||
| 202 | |||
| 203 | snprintf(base, 256, "kernel_repeat_%s", ggml_type_name(tsrc)); | ||
| 204 | snprintf(name, 256, "%s", base); | ||
| 205 | |||
| 206 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 207 | if (!res.pipeline) { | ||
| 208 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 209 | } | ||
| 210 | |||
| 211 | return res; | ||
| 212 | } | ||
| 213 | |||
| 214 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 215 | char base[256]; | ||
| 216 | char name[256]; | ||
| 217 | |||
| 218 | int op_num = -1; | ||
| 219 | |||
| 220 | switch (op->op) { | ||
| 221 | case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break; | ||
| 222 | case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break; | ||
| 223 | case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break; | ||
| 224 | case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break; | ||
| 225 | case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break; | ||
| 226 | case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break; | ||
| 227 | case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break; | ||
| 228 | case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break; | ||
| 229 | case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break; | ||
| 230 | case GGML_OP_UNARY: | ||
| 231 | switch (ggml_get_unary_op(op)) { | ||
| 232 | case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break; | ||
| 233 | case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break; | ||
| 234 | case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break; | ||
| 235 | case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break; | ||
| 236 | case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break; | ||
| 237 | case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break; | ||
| 238 | case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break; | ||
| 239 | case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break; | ||
| 240 | case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break; | ||
| 241 | case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break; | ||
| 242 | case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break; | ||
| 243 | case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break; | ||
| 244 | case GGML_UNARY_OP_HARDSWISH: op_num = OP_UNARY_NUM_HARDSWISH; break; | ||
| 245 | case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break; | ||
| 246 | case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break; | ||
| 247 | case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break; | ||
| 248 | case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break; | ||
| 249 | default: GGML_ABORT("fatal error"); | ||
| 250 | } break; | ||
| 251 | default: GGML_ABORT("fatal error"); | ||
| 252 | }; | ||
| 253 | |||
| 254 | const char * t0_str = ggml_type_name(op->src[0]->type); | ||
| 255 | const char * t_str = ggml_type_name(op->type); | ||
| 256 | |||
| 257 | const bool is_c4 = op->src[0]->ne[0] % 4 == 0; | ||
| 258 | const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768; | ||
| 259 | |||
| 260 | snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : ""); | ||
| 261 | snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt); | ||
| 262 | |||
| 263 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 264 | if (!res.pipeline) { | ||
| 265 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 266 | |||
| 267 | ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0); | ||
| 268 | ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1); | ||
| 269 | |||
| 270 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 271 | |||
| 272 | ggml_metal_cv_free(cv); | ||
| 273 | } | ||
| 274 | |||
| 275 | res.c4 = is_c4; | ||
| 276 | res.cnt = is_cnt; | ||
| 277 | |||
| 278 | return res; | ||
| 279 | } | ||
| 280 | |||
| 281 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 282 | GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); | ||
| 283 | |||
| 284 | char base[256]; | ||
| 285 | char name[256]; | ||
| 286 | |||
| 287 | const char * op_str = "undefined"; | ||
| 288 | switch (op->op) { | ||
| 289 | case GGML_OP_GLU: | ||
| 290 | switch (ggml_get_glu_op(op)) { | ||
| 291 | case GGML_GLU_OP_REGLU: op_str = "reglu"; break; | ||
| 292 | case GGML_GLU_OP_GEGLU: op_str = "geglu"; break; | ||
| 293 | case GGML_GLU_OP_SWIGLU: op_str = "swiglu"; break; | ||
| 294 | case GGML_GLU_OP_SWIGLU_OAI: op_str = "swiglu_oai"; break; | ||
| 295 | case GGML_GLU_OP_GEGLU_ERF: op_str = "geglu_erf"; break; | ||
| 296 | case GGML_GLU_OP_GEGLU_QUICK: op_str = "geglu_quick"; break; | ||
| 297 | default: GGML_ABORT("fatal error"); | ||
| 298 | } break; | ||
| 299 | default: GGML_ABORT("fatal error"); | ||
| 300 | }; | ||
| 301 | |||
| 302 | snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); | ||
| 303 | snprintf(name, 256, "%s", base); | ||
| 304 | |||
| 305 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 306 | if (!res.pipeline) { | ||
| 307 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 308 | } | ||
| 309 | |||
| 310 | return res; | ||
| 311 | } | ||
| 312 | |||
| 313 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 314 | assert(op->op == GGML_OP_SUM); | ||
| 315 | |||
| 316 | char base[256]; | ||
| 317 | char name[256]; | ||
| 318 | |||
| 319 | snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type)); | ||
| 320 | snprintf(name, 256, "%s", base); | ||
| 321 | |||
| 322 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 323 | if (!res.pipeline) { | ||
| 324 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 325 | } | ||
| 326 | |||
| 327 | return res; | ||
| 328 | } | ||
| 329 | |||
| 330 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 331 | GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); | ||
| 332 | |||
| 333 | char base[256]; | ||
| 334 | char name[256]; | ||
| 335 | |||
| 336 | const char * op_str = "undefined"; | ||
| 337 | switch (op->op) { | ||
| 338 | case GGML_OP_SUM_ROWS: | ||
| 339 | op_str = "sum_rows"; break; | ||
| 340 | case GGML_OP_MEAN: | ||
| 341 | op_str = "mean"; break; | ||
| 342 | default: GGML_ABORT("fatal error"); | ||
| 343 | }; | ||
| 344 | |||
| 345 | snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); | ||
| 346 | |||
| 347 | snprintf(name, 256, "%s", base); | ||
| 348 | |||
| 349 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 350 | if (!res.pipeline) { | ||
| 351 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 352 | } | ||
| 353 | |||
| 354 | res.smem = 32*sizeof(float); | ||
| 355 | |||
| 356 | return res; | ||
| 357 | } | ||
| 358 | |||
| 359 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 360 | GGML_ASSERT(op->op == GGML_OP_CUMSUM); | ||
| 361 | |||
| 362 | char base[256]; | ||
| 363 | char name[256]; | ||
| 364 | |||
| 365 | snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type)); | ||
| 366 | snprintf(name, 256, "%s", base); | ||
| 367 | |||
| 368 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 369 | if (!res.pipeline) { | ||
| 370 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 371 | } | ||
| 372 | |||
| 373 | return res; | ||
| 374 | } | ||
| 375 | |||
| 376 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 377 | GGML_ASSERT(op->op == GGML_OP_CUMSUM); | ||
| 378 | |||
| 379 | char base[256]; | ||
| 380 | char name[256]; | ||
| 381 | |||
| 382 | snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type)); | ||
| 383 | snprintf(name, 256, "%s", base); | ||
| 384 | |||
| 385 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 386 | if (!res.pipeline) { | ||
| 387 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 388 | } | ||
| 389 | |||
| 390 | return res; | ||
| 391 | } | ||
| 392 | |||
| 393 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 394 | GGML_ASSERT(op->op == GGML_OP_TRI); | ||
| 395 | GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); | ||
| 396 | |||
| 397 | char base[256]; | ||
| 398 | char name[256]; | ||
| 399 | |||
| 400 | const char * op_str = "tri"; | ||
| 401 | const int ttype = op->op_params[0]; | ||
| 402 | |||
| 403 | snprintf(base, 256, "kernel_%s_%s_%d", op_str, ggml_type_name(op->src[0]->type), ttype); | ||
| 404 | |||
| 405 | snprintf(name, 256, "%s", base); | ||
| 406 | |||
| 407 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 408 | if (!res.pipeline) { | ||
| 409 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 410 | } | ||
| 411 | |||
| 412 | return res; | ||
| 413 | } | ||
| 414 | |||
| 415 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 416 | GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32); | ||
| 417 | |||
| 418 | char base[256]; | ||
| 419 | char name[256]; | ||
| 420 | |||
| 421 | const char * suffix = ""; | ||
| 422 | |||
| 423 | if (op->src[0]->ne[0] % 4 == 0) { | ||
| 424 | suffix = "_4"; | ||
| 425 | } | ||
| 426 | |||
| 427 | const ggml_type tsrc1 = op->src[1] ? op->src[1]->type : GGML_TYPE_F32; | ||
| 428 | |||
| 429 | snprintf(base, 256, "kernel_soft_max_%s%s", ggml_type_name(tsrc1), suffix); | ||
| 430 | snprintf(name, 256, "%s", base); | ||
| 431 | |||
| 432 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 433 | if (!res.pipeline) { | ||
| 434 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 435 | } | ||
| 436 | |||
| 437 | res.smem = 32*sizeof(float); | ||
| 438 | |||
| 439 | return res; | ||
| 440 | } | ||
| 441 | |||
| 442 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 443 | GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); | ||
| 444 | GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); | ||
| 445 | |||
| 446 | GGML_ASSERT(ggml_is_contiguous(op->src[0])); | ||
| 447 | GGML_ASSERT(ggml_is_contiguous(op->src[1])); | ||
| 448 | |||
| 449 | char base[256]; | ||
| 450 | char name[256]; | ||
| 451 | |||
| 452 | const char * suffix = ""; | ||
| 453 | |||
| 454 | if (op->src[1]->ne[0] % 4 == 0) { | ||
| 455 | suffix = "_4"; | ||
| 456 | } | ||
| 457 | |||
| 458 | snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix); | ||
| 459 | snprintf(name, 256, "%s", base); | ||
| 460 | |||
| 461 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 462 | if (!res.pipeline) { | ||
| 463 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 464 | } | ||
| 465 | |||
| 466 | return res; | ||
| 467 | } | ||
| 468 | |||
| 469 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched(ggml_metal_library_t lib, const ggml_tensor * op, int ssm_conv_bs) { | ||
| 470 | GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); | ||
| 471 | GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); | ||
| 472 | |||
| 473 | GGML_ASSERT(ggml_is_contiguous(op->src[0])); | ||
| 474 | GGML_ASSERT(ggml_is_contiguous(op->src[1])); | ||
| 475 | |||
| 476 | char base[256]; | ||
| 477 | char name[256]; | ||
| 478 | |||
| 479 | const char * suffix = ""; | ||
| 480 | if (op->src[1]->ne[0] % 4 == 0) { | ||
| 481 | suffix = "_4"; | ||
| 482 | } | ||
| 483 | |||
| 484 | snprintf(base, 256, "kernel_ssm_conv_%s_%s_batched%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix); | ||
| 485 | snprintf(name, 256, "%s_ssm_conv_bs=%d", base, ssm_conv_bs); | ||
| 486 | |||
| 487 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 488 | if (!res.pipeline) { | ||
| 489 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 490 | |||
| 491 | ggml_metal_cv_set_int16(cv, ssm_conv_bs, FC_SSM_CONV + 0); | ||
| 492 | |||
| 493 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 494 | |||
| 495 | ggml_metal_cv_free(cv); | ||
| 496 | } | ||
| 497 | |||
| 498 | return res; | ||
| 499 | } | ||
| 500 | |||
| 501 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 502 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 503 | |||
| 504 | char base[256]; | ||
| 505 | char name[256]; | ||
| 506 | |||
| 507 | const int nsg = (ne00 + 31)/32; | ||
| 508 | |||
| 509 | snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type)); | ||
| 510 | snprintf(name, 256, "%s_nsg=%d", base, nsg); | ||
| 511 | |||
| 512 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 513 | if (!res.pipeline) { | ||
| 514 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 515 | } | ||
| 516 | |||
| 517 | // Shared memory layout: | ||
| 518 | // - sgptg * NW floats for partial sums (nsg * 32) | ||
| 519 | // - sgptg floats for shared_x_dt (nsg) | ||
| 520 | // - sgptg floats for shared_dA (nsg) | ||
| 521 | // Total: nsg * (32 + 2) floats | ||
| 522 | res.smem = (32 + 2)*sizeof(float)*nsg; | ||
| 523 | |||
| 524 | return res; | ||
| 525 | } | ||
| 526 | |||
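
The shared-memory comment above adds up to `nsg * (32 + 2)` floats: 32 partial-sum slots plus one `shared_x_dt` and one `shared_dA` slot per simdgroup. A quick worked example with a made-up row length:

```cpp
// Worked example of the ssm_scan threadgroup memory size (ne00 = 128 is an illustrative value).
const int    ne00 = 128;
const int    nsg  = (ne00 + 31)/32;              // 4 simdgroups
const size_t smem = (32 + 2)*sizeof(float)*nsg;  // (32 + 1 + 1) floats per simdgroup = 544 bytes
```
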
| 527 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 528 | char base[256]; | ||
| 529 | char name[256]; | ||
| 530 | |||
| 531 | const int64_t C = op->ne[0]; | ||
| 532 | const int64_t H = op->src[0]->ne[1]; | ||
| 533 | |||
| 534 | switch (op->op) { | ||
| 535 | case GGML_OP_RWKV_WKV6: | ||
| 536 | { | ||
| 537 | GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32); | ||
| 538 | GGML_ASSERT(C % H == 0); | ||
| 539 | GGML_ASSERT(C / H == 64); | ||
| 540 | |||
| 541 | snprintf(base, 256, "kernel_rwkv_wkv6_%s", ggml_type_name(op->src[0]->type)); | ||
| 542 | } break; | ||
| 543 | case GGML_OP_RWKV_WKV7: | ||
| 544 | { | ||
| 545 | GGML_ASSERT(op->src[6]->type == GGML_TYPE_F32); | ||
| 546 | GGML_ASSERT(C % H == 0); | ||
| 547 | GGML_ASSERT(C / H == 64); | ||
| 548 | |||
| 549 | snprintf(base, 256, "kernel_rwkv_wkv7_%s", ggml_type_name(op->src[0]->type)); | ||
| 550 | } break; | ||
| 551 | default: | ||
| 552 | GGML_ABORT("fatal error"); | ||
| 553 | } | ||
| 554 | |||
| 555 | snprintf(name, 256, "%s", base); | ||
| 556 | |||
| 557 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 558 | if (!res.pipeline) { | ||
| 559 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 560 | } | ||
| 561 | |||
| 562 | return res; | ||
| 563 | } | ||
| 564 | |||
| 565 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 566 | char base[256]; | ||
| 567 | char name[256]; | ||
| 568 | |||
| 569 | const int nsg = 8; | ||
| 570 | const int n = op->src[1]->ne[1]; | ||
| 571 | const int k = op->src[1]->ne[0]; | ||
| 572 | |||
| 573 | snprintf(base, 256, "kernel_solve_tri_%s", ggml_type_name(op->src[0]->type)); | ||
| 574 | snprintf(name, 256, "%s_nsg=%d_n=%d_k=%d", base, nsg, n, k); | ||
| 575 | |||
| 576 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 577 | if (!res.pipeline) { | ||
| 578 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 579 | |||
| 580 | ggml_metal_cv_set_int16(cv, nsg, FC_SOLVE_TRI + 0); | ||
| 581 | ggml_metal_cv_set_int16(cv, n, FC_SOLVE_TRI + 1); | ||
| 582 | ggml_metal_cv_set_int16(cv, k, FC_SOLVE_TRI + 2); | ||
| 583 | |||
| 584 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 585 | |||
| 586 | ggml_metal_cv_free(cv); | ||
| 587 | } | ||
| 588 | |||
| 589 | res.nsg = nsg; | ||
| 590 | res.smem = GGML_PAD(GGML_PAD(n, 32)*nsg*sizeof(float), 16); | ||
| 591 | |||
| 592 | return res; | ||
| 593 | } | ||
| 594 | |||
| 595 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { | ||
| 596 | char base[256]; | ||
| 597 | char name[256]; | ||
| 598 | |||
| 599 | snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg); | ||
| 600 | snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg); | ||
| 601 | |||
| 602 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 603 | if (!res.pipeline) { | ||
| 604 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 605 | |||
| 606 | ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); | ||
| 607 | ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); | ||
| 608 | |||
| 609 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 610 | |||
| 611 | ggml_metal_cv_free(cv); | ||
| 612 | } | ||
| 613 | |||
| 614 | return res; | ||
| 615 | } | ||
| 616 | |||
| 617 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 618 | char base[256]; | ||
| 619 | char name[256]; | ||
| 620 | |||
| 621 | const ggml_type tsrc0 = op->src[0]->type; | ||
| 622 | const ggml_type tsrc1 = op->src[1]->type; | ||
| 623 | |||
| 624 | const bool bc_inp = op->src[0]->ne[0] % 32 != 0; | ||
| 625 | const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0; | ||
| 626 | |||
| 627 | snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); | ||
| 628 | snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out); | ||
| 629 | |||
| 630 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 631 | if (!res.pipeline) { | ||
| 632 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 633 | |||
| 634 | ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); | ||
| 635 | ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); | ||
| 636 | |||
| 637 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 638 | |||
| 639 | ggml_metal_cv_free(cv); | ||
| 640 | } | ||
| 641 | |||
| 642 | // when the output size is not a multiple of 64x32, we need extra smem to prevent out-of-bounds writes | ||
| 643 | res.smem = bc_out ? 8192 : 4096 + 2048; | ||
| 644 | |||
| 645 | return res; | ||
| 646 | } | ||
| 647 | |||
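
In the matrix-matrix path the threadgroup memory request is fixed: 4096 + 2048 = 6144 bytes when the output is a clean multiple of the 64x32 tile, rounded up to 8192 bytes when `bc_out` is set so that edge tiles can be staged in shared memory rather than written out of bounds (per the comment above). For reference:

```cpp
// smem sizing straight from the getter above: 6144 bytes normally, 8192 with output bounds checks.
const bool   bc_out = true;
const size_t smem   = bc_out ? 8192 : 4096 + 2048; // 8192 vs 6144 bytes
```
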
| 648 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 649 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 650 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 651 | |||
| 652 | char base[256]; | ||
| 653 | char name[256]; | ||
| 654 | |||
| 655 | int nsg = 0; // number of simdgroups | ||
| 656 | int nr0 = 0; // number of src0 rows per simdgroup | ||
| 657 | int nr1 = 1; // number of src1 rows per threadgroup | ||
| 658 | |||
| 659 | size_t smem = 0; // shared memory | ||
| 660 | |||
| 661 | const ggml_type tsrc0 = op->src[0]->type; | ||
| 662 | const ggml_type tsrc1 = op->src[1]->type; | ||
| 663 | |||
| 664 | const char * suffix = ""; | ||
| 665 | |||
| 666 | // use custom matrix x vector kernel | ||
| 667 | switch (tsrc0) { | ||
| 668 | case GGML_TYPE_F32: | ||
| 669 | case GGML_TYPE_F16: | ||
| 670 | case GGML_TYPE_BF16: | ||
| 671 | { | ||
| 672 | if (ne00 < 32) { | ||
| 673 | nsg = 1; | ||
| 674 | nr0 = 32; | ||
| 675 | nr1 = 1; | ||
| 676 | suffix = "_short"; | ||
| 677 | } else { | ||
| 678 | nsg = std::min(4, (ne00 + 127) / 128); | ||
| 679 | nr0 = 2; | ||
| 680 | nr1 = 1; | ||
| 681 | smem = 32*sizeof(float)*nr0; | ||
| 682 | suffix = ne00 % 4 == 0 ? "_4" : ""; | ||
| 683 | } | ||
| 684 | } break; | ||
| 685 | case GGML_TYPE_Q4_0: | ||
| 686 | { | ||
| 687 | nsg = N_SG_Q4_0; | ||
| 688 | nr0 = N_R0_Q4_0; | ||
| 689 | } break; | ||
| 690 | case GGML_TYPE_Q4_1: | ||
| 691 | { | ||
| 692 | nsg = N_SG_Q4_1; | ||
| 693 | nr0 = N_R0_Q4_1; | ||
| 694 | } break; | ||
| 695 | case GGML_TYPE_Q5_0: | ||
| 696 | { | ||
| 697 | nsg = N_SG_Q5_0; | ||
| 698 | nr0 = N_R0_Q5_0; | ||
| 699 | } break; | ||
| 700 | case GGML_TYPE_Q5_1: | ||
| 701 | { | ||
| 702 | nsg = N_SG_Q5_1; | ||
| 703 | nr0 = N_R0_Q5_1; | ||
| 704 | } break; | ||
| 705 | case GGML_TYPE_Q8_0: | ||
| 706 | { | ||
| 707 | nsg = N_SG_Q8_0; | ||
| 708 | nr0 = N_R0_Q8_0; | ||
| 709 | smem = 32*sizeof(float)*N_R0_Q8_0; | ||
| 710 | } break; | ||
| 711 | case GGML_TYPE_MXFP4: | ||
| 712 | { | ||
| 713 | nsg = N_SG_MXFP4; | ||
| 714 | nr0 = N_R0_MXFP4; | ||
| 715 | smem = 32*sizeof(float); | ||
| 716 | } break; | ||
| 717 | case GGML_TYPE_Q2_K: | ||
| 718 | { | ||
| 719 | nsg = N_SG_Q2_K; | ||
| 720 | nr0 = N_R0_Q2_K; | ||
| 721 | } break; | ||
| 722 | case GGML_TYPE_Q3_K: | ||
| 723 | { | ||
| 724 | nsg = N_SG_Q3_K; | ||
| 725 | nr0 = N_R0_Q3_K; | ||
| 726 | } break; | ||
| 727 | case GGML_TYPE_Q4_K: | ||
| 728 | { | ||
| 729 | nsg = N_SG_Q4_K; | ||
| 730 | nr0 = N_R0_Q4_K; | ||
| 731 | } break; | ||
| 732 | case GGML_TYPE_Q5_K: | ||
| 733 | { | ||
| 734 | nsg = N_SG_Q5_K; | ||
| 735 | nr0 = N_R0_Q5_K; | ||
| 736 | } break; | ||
| 737 | case GGML_TYPE_Q6_K: | ||
| 738 | { | ||
| 739 | nsg = N_SG_Q6_K; | ||
| 740 | nr0 = N_R0_Q6_K; | ||
| 741 | } break; | ||
| 742 | case GGML_TYPE_IQ2_XXS: | ||
| 743 | { | ||
| 744 | nsg = N_SG_IQ2_XXS; | ||
| 745 | nr0 = N_R0_IQ2_XXS; | ||
| 746 | smem = 256*8+128; | ||
| 747 | } break; | ||
| 748 | case GGML_TYPE_IQ2_XS: | ||
| 749 | { | ||
| 750 | nsg = N_SG_IQ2_XS; | ||
| 751 | nr0 = N_R0_IQ2_XS; | ||
| 752 | smem = 512*8+128; | ||
| 753 | } break; | ||
| 754 | case GGML_TYPE_IQ3_XXS: | ||
| 755 | { | ||
| 756 | nsg = N_SG_IQ3_XXS; | ||
| 757 | nr0 = N_R0_IQ3_XXS; | ||
| 758 | smem = 256*4+128; | ||
| 759 | } break; | ||
| 760 | case GGML_TYPE_IQ3_S: | ||
| 761 | { | ||
| 762 | nsg = N_SG_IQ3_S; | ||
| 763 | nr0 = N_R0_IQ3_S; | ||
| 764 | smem = 512*4; | ||
| 765 | } break; | ||
| 766 | case GGML_TYPE_IQ2_S: | ||
| 767 | { | ||
| 768 | nsg = N_SG_IQ2_S; | ||
| 769 | nr0 = N_R0_IQ2_S; | ||
| 770 | } break; | ||
| 771 | case GGML_TYPE_IQ1_S: | ||
| 772 | { | ||
| 773 | nsg = N_SG_IQ1_S; | ||
| 774 | nr0 = N_R0_IQ1_S; | ||
| 775 | } break; | ||
| 776 | case GGML_TYPE_IQ1_M: | ||
| 777 | { | ||
| 778 | nsg = N_SG_IQ1_M; | ||
| 779 | nr0 = N_R0_IQ1_M; | ||
| 780 | } break; | ||
| 781 | case GGML_TYPE_IQ4_NL: | ||
| 782 | { | ||
| 783 | nsg = N_SG_IQ4_NL; | ||
| 784 | nr0 = N_R0_IQ4_NL; | ||
| 785 | smem = 32*sizeof(float); | ||
| 786 | } break; | ||
| 787 | case GGML_TYPE_IQ4_XS: | ||
| 788 | { | ||
| 789 | nsg = N_SG_IQ4_XS; | ||
| 790 | nr0 = N_R0_IQ4_XS; | ||
| 791 | smem = 32*sizeof(float); | ||
| 792 | } break; | ||
| 793 | default: | ||
| 794 | { | ||
| 795 | GGML_LOG_ERROR("Asserting on type %d\n", (int) tsrc0); | ||
| 796 | GGML_ABORT("not implemented"); | ||
| 797 | } | ||
| 798 | }; | ||
| 799 | |||
| 800 | snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); | ||
| 801 | snprintf(name, 256, "%s_nsg=%d", base, nsg); | ||
| 802 | |||
| 803 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 804 | if (!res.pipeline) { | ||
| 805 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 806 | |||
| 807 | ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); | ||
| 808 | |||
| 809 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 810 | |||
| 811 | ggml_metal_cv_free(cv); | ||
| 812 | } | ||
| 813 | |||
| 814 | res.nr0 = nr0; | ||
| 815 | res.nr1 = nr1; | ||
| 816 | res.nsg = nsg; | ||
| 817 | res.smem = smem; | ||
| 818 | |||
| 819 | return res; | ||
| 820 | } | ||
| 821 | |||
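
For the float types the launch parameters are derived from `ne00`, while the quantized types take fixed `N_SG_*`/`N_R0_*` pairs. As a worked example, an F16 source with an illustrative `ne00 = 4096` lands in the long-row branch above:

```cpp
// Worked example of the float-path parameters in get_pipeline_mul_mv (ne00 = 4096 is illustrative).
const int    ne00   = 4096;
const int    nsg    = std::min(4, (ne00 + 127)/128); // = 4 simdgroups
const int    nr0    = 2;                             // src0 rows per simdgroup
const size_t smem   = 32*sizeof(float)*nr0;          // = 256 bytes
const char * suffix = ne00 % 4 == 0 ? "_4" : "";     // vectorized "_4" kernel variant
```
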
| 822 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) { | ||
| 823 | char base[256]; | ||
| 824 | char name[256]; | ||
| 825 | |||
| 826 | snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20); | ||
| 827 | snprintf(name, 256, "%s_ne02=%d", base, ne02); | ||
| 828 | |||
| 829 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 830 | if (!res.pipeline) { | ||
| 831 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 832 | } | ||
| 833 | |||
| 834 | res.smem = (size_t) ne02*ne20*sizeof(uint16_t); | ||
| 835 | |||
| 836 | return res; | ||
| 837 | } | ||
| 838 | |||
| 839 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 840 | char base[256]; | ||
| 841 | char name[256]; | ||
| 842 | |||
| 843 | const ggml_type tsrc0 = op->src[0]->type; | ||
| 844 | const ggml_type tsrc1 = op->src[1]->type; | ||
| 845 | |||
| 846 | const bool bc_inp = op->src[0]->ne[0] % 32 != 0; | ||
| 847 | |||
| 848 | snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); | ||
| 849 | snprintf(name, 256, "%s_bci=%d", base, bc_inp); | ||
| 850 | |||
| 851 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 852 | if (!res.pipeline) { | ||
| 853 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 854 | |||
| 855 | ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); | ||
| 856 | |||
| 857 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 858 | |||
| 859 | ggml_metal_cv_free(cv); | ||
| 860 | } | ||
| 861 | |||
| 862 | res.smem = 8192; | ||
| 863 | |||
| 864 | return res; | ||
| 865 | } | ||
| 866 | |||
| 867 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 868 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 869 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 870 | |||
| 871 | char base[256]; | ||
| 872 | char name[256]; | ||
| 873 | |||
| 874 | int nsg = 0; // number of simdgroups | ||
| 875 | int nr0 = 0; // number of src0 rows per simdgroup | ||
| 876 | int nr1 = 1; // number of src1 rows per threadgroup | ||
| 877 | |||
| 878 | size_t smem = 0; // shared memory | ||
| 879 | |||
| 880 | const ggml_type tsrc0 = op->src[0]->type; | ||
| 881 | const ggml_type tsrc1 = op->src[1]->type; | ||
| 882 | |||
| 883 | const char * suffix = ""; | ||
| 884 | |||
| 885 | // use custom matrix x vector kernel | ||
| 886 | switch (tsrc0) { | ||
| 887 | case GGML_TYPE_F32: | ||
| 888 | case GGML_TYPE_F16: | ||
| 889 | case GGML_TYPE_BF16: | ||
| 890 | { | ||
| 891 | nsg = std::min(4, (ne00 + 127) / 128); | ||
| 892 | nr0 = 2; | ||
| 893 | nr1 = 1; | ||
| 894 | smem = 32*sizeof(float)*nr0; | ||
| 895 | suffix = ne00 % 4 == 0 ? "_4" : ""; | ||
| 896 | } break; | ||
| 897 | case GGML_TYPE_Q4_0: | ||
| 898 | { | ||
| 899 | nsg = N_SG_Q4_0; | ||
| 900 | nr0 = N_R0_Q4_0; | ||
| 901 | } break; | ||
| 902 | case GGML_TYPE_Q4_1: | ||
| 903 | { | ||
| 904 | nsg = N_SG_Q4_1; | ||
| 905 | nr0 = N_R0_Q4_1; | ||
| 906 | } break; | ||
| 907 | case GGML_TYPE_Q5_0: | ||
| 908 | { | ||
| 909 | nsg = N_SG_Q5_0; | ||
| 910 | nr0 = N_R0_Q5_0; | ||
| 911 | } break; | ||
| 912 | case GGML_TYPE_Q5_1: | ||
| 913 | { | ||
| 914 | nsg = N_SG_Q5_1; | ||
| 915 | nr0 = N_R0_Q5_1; | ||
| 916 | } break; | ||
| 917 | case GGML_TYPE_Q8_0: | ||
| 918 | { | ||
| 919 | nsg = N_SG_Q8_0; | ||
| 920 | nr0 = N_R0_Q8_0; | ||
| 921 | smem = 32*sizeof(float)*N_R0_Q8_0; | ||
| 922 | } break; | ||
| 923 | case GGML_TYPE_MXFP4: | ||
| 924 | { | ||
| 925 | nsg = N_SG_MXFP4; | ||
| 926 | nr0 = N_R0_MXFP4; | ||
| 927 | smem = 32*sizeof(float); | ||
| 928 | } break; | ||
| 929 | case GGML_TYPE_Q2_K: | ||
| 930 | { | ||
| 931 | nsg = N_SG_Q2_K; | ||
| 932 | nr0 = N_R0_Q2_K; | ||
| 933 | } break; | ||
| 934 | case GGML_TYPE_Q3_K: | ||
| 935 | { | ||
| 936 | nsg = N_SG_Q3_K; | ||
| 937 | nr0 = N_R0_Q3_K; | ||
| 938 | } break; | ||
| 939 | case GGML_TYPE_Q4_K: | ||
| 940 | { | ||
| 941 | nsg = N_SG_Q4_K; | ||
| 942 | nr0 = N_R0_Q4_K; | ||
| 943 | } break; | ||
| 944 | case GGML_TYPE_Q5_K: | ||
| 945 | { | ||
| 946 | nsg = N_SG_Q5_K; | ||
| 947 | nr0 = N_R0_Q5_K; | ||
| 948 | } break; | ||
| 949 | case GGML_TYPE_Q6_K: | ||
| 950 | { | ||
| 951 | nsg = N_SG_Q6_K; | ||
| 952 | nr0 = N_R0_Q6_K; | ||
| 953 | } break; | ||
| 954 | case GGML_TYPE_IQ2_XXS: | ||
| 955 | { | ||
| 956 | nsg = N_SG_IQ2_XXS; | ||
| 957 | nr0 = N_R0_IQ2_XXS; | ||
| 958 | smem = 256*8+128; | ||
| 959 | } break; | ||
| 960 | case GGML_TYPE_IQ2_XS: | ||
| 961 | { | ||
| 962 | nsg = N_SG_IQ2_XS; | ||
| 963 | nr0 = N_R0_IQ2_XS; | ||
| 964 | smem = 512*8+128; | ||
| 965 | } break; | ||
| 966 | case GGML_TYPE_IQ3_XXS: | ||
| 967 | { | ||
| 968 | nsg = N_SG_IQ3_XXS; | ||
| 969 | nr0 = N_R0_IQ3_XXS; | ||
| 970 | smem = 256*4+128; | ||
| 971 | } break; | ||
| 972 | case GGML_TYPE_IQ3_S: | ||
| 973 | { | ||
| 974 | nsg = N_SG_IQ3_S; | ||
| 975 | nr0 = N_R0_IQ3_S; | ||
| 976 | smem = 512*4; | ||
| 977 | } break; | ||
| 978 | case GGML_TYPE_IQ2_S: | ||
| 979 | { | ||
| 980 | nsg = N_SG_IQ2_S; | ||
| 981 | nr0 = N_R0_IQ2_S; | ||
| 982 | } break; | ||
| 983 | case GGML_TYPE_IQ1_S: | ||
| 984 | { | ||
| 985 | nsg = N_SG_IQ1_S; | ||
| 986 | nr0 = N_R0_IQ1_S; | ||
| 987 | } break; | ||
| 988 | case GGML_TYPE_IQ1_M: | ||
| 989 | { | ||
| 990 | nsg = N_SG_IQ1_M; | ||
| 991 | nr0 = N_R0_IQ1_M; | ||
| 992 | } break; | ||
| 993 | case GGML_TYPE_IQ4_NL: | ||
| 994 | { | ||
| 995 | nsg = N_SG_IQ4_NL; | ||
| 996 | nr0 = N_R0_IQ4_NL; | ||
| 997 | smem = 32*sizeof(float); | ||
| 998 | } break; | ||
| 999 | case GGML_TYPE_IQ4_XS: | ||
| 1000 | { | ||
| 1001 | nsg = N_SG_IQ4_XS; | ||
| 1002 | nr0 = N_R0_IQ4_XS; | ||
| 1003 | smem = 32*sizeof(float); | ||
| 1004 | } break; | ||
| 1005 | default: | ||
| 1006 | { | ||
| 1007 | GGML_LOG_ERROR("Asserting on type %d\n", (int) tsrc0); | ||
| 1008 | GGML_ABORT("not implemented"); | ||
| 1009 | } | ||
| 1010 | }; | ||
| 1011 | |||
| 1012 | snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); | ||
| 1013 | snprintf(name, 256, "%s_nsg=%d", base, nsg); | ||
| 1014 | |||
| 1015 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1016 | if (!res.pipeline) { | ||
| 1017 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 1018 | |||
| 1019 | ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); | ||
| 1020 | |||
| 1021 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 1022 | |||
| 1023 | ggml_metal_cv_free(cv); | ||
| 1024 | } | ||
| 1025 | |||
| 1026 | res.nr0 = nr0; | ||
| 1027 | res.nr1 = nr1; | ||
| 1028 | res.nsg = nsg; | ||
| 1029 | res.smem = smem; | ||
| 1030 | |||
| 1031 | return res; | ||
| 1032 | } | ||
| 1033 | |||
| 1034 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1035 | GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); | ||
| 1036 | GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); | ||
| 1037 | GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); | ||
| 1038 | |||
| 1039 | char base[256]; | ||
| 1040 | char name[256]; | ||
| 1041 | |||
| 1042 | snprintf(base, 256, "kernel_argmax_%s", ggml_type_name(op->src[0]->type)); | ||
| 1043 | snprintf(name, 256, "%s", base); | ||
| 1044 | |||
| 1045 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1046 | if (!res.pipeline) { | ||
| 1047 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1048 | } | ||
| 1049 | |||
| 1050 | res.smem = 32*(sizeof(float) + sizeof(int32_t)); | ||
| 1051 | |||
| 1052 | return res; | ||
| 1053 | } | ||
| 1054 | |||
| 1055 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1056 | assert(op->op == GGML_OP_ARGSORT); | ||
| 1057 | |||
| 1058 | char base[256]; | ||
| 1059 | char name[256]; | ||
| 1060 | |||
| 1061 | ggml_sort_order order = (ggml_sort_order) op->op_params[0]; | ||
| 1062 | |||
| 1063 | const char * order_str = "undefined"; | ||
| 1064 | switch (order) { | ||
| 1065 | case GGML_SORT_ORDER_ASC: order_str = "asc"; break; | ||
| 1066 | case GGML_SORT_ORDER_DESC: order_str = "desc"; break; | ||
| 1067 | default: GGML_ABORT("fatal error"); | ||
| 1068 | }; | ||
| 1069 | |||
| 1070 | snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); | ||
| 1071 | snprintf(name, 256, "%s", base); | ||
| 1072 | |||
| 1073 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1074 | if (!res.pipeline) { | ||
| 1075 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1076 | } | ||
| 1077 | |||
| 1078 | return res; | ||
| 1079 | } | ||
| 1080 | |||
| 1081 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1082 | assert(op->op == GGML_OP_ARGSORT); | ||
| 1083 | |||
| 1084 | char base[256]; | ||
| 1085 | char name[256]; | ||
| 1086 | |||
| 1087 | ggml_sort_order order = (ggml_sort_order) op->op_params[0]; | ||
| 1088 | |||
| 1089 | const char * order_str = "undefined"; | ||
| 1090 | switch (order) { | ||
| 1091 | case GGML_SORT_ORDER_ASC: order_str = "asc"; break; | ||
| 1092 | case GGML_SORT_ORDER_DESC: order_str = "desc"; break; | ||
| 1093 | default: GGML_ABORT("fatal error"); | ||
| 1094 | }; | ||
| 1095 | |||
| 1096 | snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); | ||
| 1097 | snprintf(name, 256, "%s", base); | ||
| 1098 | |||
| 1099 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1100 | if (!res.pipeline) { | ||
| 1101 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1102 | } | ||
| 1103 | |||
| 1104 | return res; | ||
| 1105 | } | ||
| 1106 | |||
| 1107 | // note: reuse the argsort kernel for top_k | ||
| 1108 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1109 | assert(op->op == GGML_OP_TOP_K); | ||
| 1110 | |||
| 1111 | char base[256]; | ||
| 1112 | char name[256]; | ||
| 1113 | |||
| 1114 | // note: the top_k kernel is always descending order | ||
| 1115 | ggml_sort_order order = GGML_SORT_ORDER_DESC; | ||
| 1116 | |||
| 1117 | const char * order_str = "undefined"; | ||
| 1118 | switch (order) { | ||
| 1119 | case GGML_SORT_ORDER_ASC: order_str = "asc"; break; | ||
| 1120 | case GGML_SORT_ORDER_DESC: order_str = "desc"; break; | ||
| 1121 | default: GGML_ABORT("fatal error"); | ||
| 1122 | }; | ||
| 1123 | |||
| 1124 | snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); | ||
| 1125 | snprintf(name, 256, "%s", base); | ||
| 1126 | |||
| 1127 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1128 | if (!res.pipeline) { | ||
| 1129 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1130 | } | ||
| 1131 | |||
| 1132 | return res; | ||
| 1133 | } | ||
| 1134 | |||
| 1135 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1136 | assert(op->op == GGML_OP_TOP_K); | ||
| 1137 | |||
| 1138 | char base[256]; | ||
| 1139 | char name[256]; | ||
| 1140 | |||
| 1141 | ggml_sort_order order = GGML_SORT_ORDER_DESC; | ||
| 1142 | |||
| 1143 | const char * order_str = "undefined"; | ||
| 1144 | switch (order) { | ||
| 1145 | case GGML_SORT_ORDER_ASC: order_str = "asc"; break; | ||
| 1146 | case GGML_SORT_ORDER_DESC: order_str = "desc"; break; | ||
| 1147 | default: GGML_ABORT("fatal error"); | ||
| 1148 | }; | ||
| 1149 | |||
| 1150 | snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str); | ||
| 1151 | snprintf(name, 256, "%s", base); | ||
| 1152 | |||
| 1153 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1154 | if (!res.pipeline) { | ||
| 1155 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1156 | } | ||
| 1157 | |||
| 1158 | return res; | ||
| 1159 | } | ||
| 1160 | |||
| 1161 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad( | ||
| 1162 | ggml_metal_library_t lib, | ||
| 1163 | const struct ggml_tensor * op, | ||
| 1164 | bool has_mask, | ||
| 1165 | int32_t ncpsg) { | ||
| 1166 | assert(op->op == GGML_OP_FLASH_ATTN_EXT); | ||
| 1167 | GGML_UNUSED(op); | ||
| 1168 | |||
| 1169 | char base[256]; | ||
| 1170 | char name[256]; | ||
| 1171 | |||
| 1172 | snprintf(base, 256, "kernel_%s", | ||
| 1173 | "flash_attn_ext_pad"); | ||
| 1174 | |||
| 1175 | snprintf(name, 256, "%s_mask=%d_ncpsg=%d", | ||
| 1176 | base, | ||
| 1177 | has_mask, | ||
| 1178 | ncpsg); | ||
| 1179 | |||
| 1180 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1181 | if (!res.pipeline) { | ||
| 1182 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 1183 | |||
| 1184 | ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0); | ||
| 1185 | //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1); | ||
| 1186 | //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2); | ||
| 1187 | //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3); | ||
| 1188 | |||
| 1189 | //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20); | ||
| 1190 | //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21); | ||
| 1191 | //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22); | ||
| 1192 | //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23); | ||
| 1193 | //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24); | ||
| 1194 | ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25); | ||
| 1195 | |||
| 1196 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 1197 | |||
| 1198 | ggml_metal_cv_free(cv); | ||
| 1199 | } | ||
| 1200 | |||
| 1201 | return res; | ||
| 1202 | } | ||
| 1203 | |||
| 1204 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk( | ||
| 1205 | ggml_metal_library_t lib, | ||
| 1206 | const struct ggml_tensor * op, | ||
| 1207 | int32_t nqptg, | ||
| 1208 | int32_t ncpsg) { | ||
| 1209 | assert(op->op == GGML_OP_FLASH_ATTN_EXT); | ||
| 1210 | GGML_UNUSED(op); | ||
| 1211 | |||
| 1212 | char base[256]; | ||
| 1213 | char name[256]; | ||
| 1214 | |||
| 1215 | snprintf(base, 256, "kernel_%s", | ||
| 1216 | "flash_attn_ext_blk"); | ||
| 1217 | |||
| 1218 | snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d", | ||
| 1219 | base, | ||
| 1220 | nqptg, | ||
| 1221 | ncpsg); | ||
| 1222 | |||
| 1223 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1224 | if (!res.pipeline) { | ||
| 1225 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 1226 | |||
| 1227 | //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0); | ||
| 1228 | //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1); | ||
| 1229 | //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2); | ||
| 1230 | //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3); | ||
| 1231 | |||
| 1232 | //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20); | ||
| 1233 | //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21); | ||
| 1234 | //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22); | ||
| 1235 | //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23); | ||
| 1236 | ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24); | ||
| 1237 | ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25); | ||
| 1238 | |||
| 1239 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 1240 | |||
| 1241 | ggml_metal_cv_free(cv); | ||
| 1242 | } | ||
| 1243 | |||
| 1244 | return res; | ||
| 1245 | } | ||
| 1246 | |||
| 1247 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext( | ||
| 1248 | ggml_metal_library_t lib, | ||
| 1249 | const ggml_tensor * op, | ||
| 1250 | bool has_mask, | ||
| 1251 | bool has_sinks, | ||
| 1252 | bool has_bias, | ||
| 1253 | bool has_scap, | ||
| 1254 | bool has_kvpad, | ||
| 1255 | int32_t nsg) { | ||
| 1256 | assert(op->op == GGML_OP_FLASH_ATTN_EXT); | ||
| 1257 | |||
| 1258 | char base[256]; | ||
| 1259 | char name[256]; | ||
| 1260 | |||
| 1261 | const int32_t dk = (int32_t) op->src[1]->ne[0]; | ||
| 1262 | const int32_t dv = (int32_t) op->src[2]->ne[0]; | ||
| 1263 | |||
| 1264 | const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; | ||
| 1265 | const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; | ||
| 1266 | |||
| 1267 | // do bounds checks for the mask? | ||
| 1268 | const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0); | ||
| 1269 | |||
| 1270 | snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", | ||
| 1271 | "flash_attn_ext", | ||
| 1272 | ggml_type_name(op->src[1]->type), | ||
| 1273 | dk, | ||
| 1274 | dv); | ||
| 1275 | |||
| 1276 | snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d", | ||
| 1277 | base, | ||
| 1278 | has_mask, | ||
| 1279 | has_sinks, | ||
| 1280 | has_bias, | ||
| 1281 | has_scap, | ||
| 1282 | has_kvpad, | ||
| 1283 | bc_mask, | ||
| 1284 | ns10, | ||
| 1285 | ns20, | ||
| 1286 | nsg); | ||
| 1287 | |||
| 1288 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1289 | if (!res.pipeline) { | ||
| 1290 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 1291 | |||
| 1292 | ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT + 0); | ||
| 1293 | ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1); | ||
| 1294 | ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2); | ||
| 1295 | ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3); | ||
| 1296 | ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4); | ||
| 1297 | |||
| 1298 | ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10); | ||
| 1299 | |||
| 1300 | ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20); | ||
| 1301 | ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21); | ||
| 1302 | ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT + 22); | ||
| 1303 | |||
| 1304 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 1305 | |||
| 1306 | ggml_metal_cv_free(cv); | ||
| 1307 | } | ||
| 1308 | |||
| 1309 | return res; | ||
| 1310 | } | ||
| 1311 | |||
| 1312 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec( | ||
| 1313 | ggml_metal_library_t lib, | ||
| 1314 | const ggml_tensor * op, | ||
| 1315 | bool has_mask, | ||
| 1316 | bool has_sinks, | ||
| 1317 | bool has_bias, | ||
| 1318 | bool has_scap, | ||
| 1319 | bool has_kvpad, | ||
| 1320 | int32_t nsg, | ||
| 1321 | int32_t nwg) { | ||
| 1322 | assert(op->op == GGML_OP_FLASH_ATTN_EXT); | ||
| 1323 | |||
| 1324 | char base[256]; | ||
| 1325 | char name[256]; | ||
| 1326 | |||
| 1327 | const int32_t dk = (int32_t) op->src[1]->ne[0]; | ||
| 1328 | const int32_t dv = (int32_t) op->src[2]->ne[0]; | ||
| 1329 | |||
| 1330 | const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; | ||
| 1331 | const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; | ||
| 1332 | |||
| 1333 | snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", | ||
| 1334 | "flash_attn_ext_vec", | ||
| 1335 | ggml_type_name(op->src[1]->type), | ||
| 1336 | dk, | ||
| 1337 | dv); | ||
| 1338 | |||
| 1339 | snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d", | ||
| 1340 | base, | ||
| 1341 | has_mask, | ||
| 1342 | has_sinks, | ||
| 1343 | has_bias, | ||
| 1344 | has_scap, | ||
| 1345 | has_kvpad, | ||
| 1346 | ns10, | ||
| 1347 | ns20, | ||
| 1348 | nsg, nwg); | ||
| 1349 | |||
| 1350 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1351 | if (!res.pipeline) { | ||
| 1352 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 1353 | |||
| 1354 | ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_VEC + 0); | ||
| 1355 | ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1); | ||
| 1356 | ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2); | ||
| 1357 | ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3); | ||
| 1358 | ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4); | ||
| 1359 | |||
| 1360 | ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20); | ||
| 1361 | ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21); | ||
| 1362 | ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_VEC + 22); | ||
| 1363 | ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC + 23); | ||
| 1364 | |||
| 1365 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 1366 | |||
| 1367 | ggml_metal_cv_free(cv); | ||
| 1368 | } | ||
| 1369 | |||
| 1370 | return res; | ||
| 1371 | } | ||
| 1372 | |||
| 1373 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( | ||
| 1374 | ggml_metal_library_t lib, | ||
| 1375 | const ggml_tensor * op, | ||
| 1376 | int32_t dv, | ||
| 1377 | int32_t nwg) { | ||
| 1378 | assert(op->op == GGML_OP_FLASH_ATTN_EXT); | ||
| 1379 | |||
| 1380 | char base[256]; | ||
| 1381 | char name[256]; | ||
| 1382 | |||
| 1383 | snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce"); | ||
| 1384 | snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg); | ||
| 1385 | |||
| 1386 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1387 | if (!res.pipeline) { | ||
| 1388 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 1389 | |||
| 1390 | ggml_metal_cv_set_int32(cv, dv, FC_FLASH_ATTN_EXT_VEC_REDUCE + 0); | ||
| 1391 | ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1); | ||
| 1392 | |||
| 1393 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 1394 | |||
| 1395 | ggml_metal_cv_free(cv); | ||
| 1396 | } | ||
| 1397 | |||
| 1398 | return res; | ||
| 1399 | |||
| 1400 | GGML_UNUSED(op); | ||
| 1401 | } | ||
| 1402 | |||
| 1403 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) { | ||
| 1404 | char base[256]; | ||
| 1405 | char name[256]; | ||
| 1406 | |||
| 1407 | int op_num = -1; | ||
| 1408 | |||
| 1409 | switch (op->op) { | ||
| 1410 | case GGML_OP_ADD: op_num = 0; break; | ||
| 1411 | case GGML_OP_SUB: op_num = 1; break; | ||
| 1412 | case GGML_OP_MUL: op_num = 2; break; | ||
| 1413 | case GGML_OP_DIV: op_num = 3; break; | ||
| 1414 | default: GGML_ABORT("fatal error"); | ||
| 1415 | }; | ||
| 1416 | |||
| 1417 | const char * t0_str = ggml_type_name(op->src[0]->type); | ||
| 1418 | const char * t1_str = ggml_type_name(op->src[1]->type); | ||
| 1419 | const char * t_str = ggml_type_name(op->type); | ||
| 1420 | |||
| 1421 | const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0); | ||
| 1422 | |||
| 1423 | const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536; | ||
| 1424 | |||
| 1425 | snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : ""); | ||
| 1426 | snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb); | ||
| 1427 | |||
| 1428 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1429 | if (!res.pipeline) { | ||
| 1430 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 1431 | |||
| 1432 | ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); | ||
| 1433 | ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1); | ||
| 1434 | ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2); | ||
| 1435 | |||
| 1436 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 1437 | |||
| 1438 | ggml_metal_cv_free(cv); | ||
| 1439 | } | ||
| 1440 | |||
| 1441 | res.c4 = is_c4; | ||
| 1442 | res.cnt = is_rb; | ||
| 1443 | |||
| 1444 | return res; | ||
| 1445 | } | ||
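// note: res.c4 and res.cnt are hints returned to the caller: c4 reports that the vectorized
// "_4" kernel variant was selected (both row sizes divisible by 4), while cnt mirrors the is_rb
// fast path (both operands contiguous, src1 a single row, small output). How the dispatch code
// consumes these flags is outside this file.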
| 1446 | |||
| 1447 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) { | ||
| 1448 | char base[256]; | ||
| 1449 | char name[256]; | ||
| 1450 | |||
| 1451 | int op_num = -1; | ||
| 1452 | |||
| 1453 | switch (op) { | ||
| 1454 | case GGML_OP_ADD: op_num = 0; break; | ||
| 1455 | case GGML_OP_SUB: op_num = 1; break; | ||
| 1456 | case GGML_OP_MUL: op_num = 2; break; | ||
| 1457 | case GGML_OP_DIV: op_num = 3; break; | ||
| 1458 | default: GGML_ABORT("fatal error"); | ||
| 1459 | }; | ||
| 1460 | |||
| 1461 | snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32"); | ||
| 1462 | snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1); | ||
| 1463 | |||
| 1464 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1465 | if (!res.pipeline) { | ||
| 1466 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 1467 | |||
| 1468 | ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); | ||
| 1469 | ggml_metal_cv_set_int16(cv, 1, FC_BIN + 1); | ||
| 1470 | ggml_metal_cv_set_bool (cv, false, FC_BIN + 2); | ||
| 1471 | |||
| 1472 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 1473 | |||
| 1474 | ggml_metal_cv_free(cv); | ||
| 1475 | } | ||
| 1476 | |||
| 1477 | return res; | ||
| 1478 | } | ||
| 1479 | |||
| 1480 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1481 | assert(op->op == GGML_OP_L2_NORM); | ||
| 1482 | |||
| 1483 | char base[256]; | ||
| 1484 | char name[256]; | ||
| 1485 | |||
| 1486 | const bool is_c4 = op->src[0]->ne[0] % 4 == 0; | ||
| 1487 | |||
| 1488 | const char * t0_str = ggml_type_name(op->src[0]->type); | ||
| 1489 | const char * t_str = ggml_type_name(op->type); | ||
| 1490 | |||
| 1491 | snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : ""); | ||
| 1492 | snprintf(name, 256, "%s", base); | ||
| 1493 | |||
| 1494 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1495 | if (!res.pipeline) { | ||
| 1496 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1497 | } | ||
| 1498 | |||
| 1499 | res.c4 = is_c4; | ||
| 1500 | res.smem = 32*sizeof(float); | ||
| 1501 | |||
| 1502 | return res; | ||
| 1503 | } | ||
| 1504 | |||
| 1505 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1506 | assert(op->op == GGML_OP_GROUP_NORM); | ||
| 1507 | |||
| 1508 | GGML_ASSERT(ggml_is_contiguous(op->src[0])); | ||
| 1509 | |||
| 1510 | char base[256]; | ||
| 1511 | char name[256]; | ||
| 1512 | |||
| 1513 | snprintf(base, 256, "kernel_group_norm_f32"); | ||
| 1514 | snprintf(name, 256, "%s", base); | ||
| 1515 | |||
| 1516 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1517 | if (!res.pipeline) { | ||
| 1518 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1519 | } | ||
| 1520 | |||
| 1521 | res.smem = 32*sizeof(float); | ||
| 1522 | |||
| 1523 | return res; | ||
| 1524 | } | ||
| 1525 | |||
| 1526 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) { | ||
| 1527 | assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM); | ||
| 1528 | |||
| 1529 | GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); | ||
| 1530 | |||
| 1531 | char base[256]; | ||
| 1532 | char name[256]; | ||
| 1533 | |||
| 1534 | const char * suffix = ""; | ||
| 1535 | if (op->ne[0] % 4 == 0) { | ||
| 1536 | suffix = "_4"; | ||
| 1537 | } | ||
| 1538 | |||
| 1539 | switch (op->op) { | ||
| 1540 | case GGML_OP_NORM: | ||
| 1541 | switch (n_fuse) { | ||
| 1542 | case 1: snprintf(base, 256, "kernel_norm_f32%s", suffix); break; | ||
| 1543 | case 2: snprintf(base, 256, "kernel_norm_mul_f32%s", suffix); break; | ||
| 1544 | case 3: snprintf(base, 256, "kernel_norm_mul_add_f32%s", suffix); break; | ||
| 1545 | default: GGML_ABORT("fatal error"); | ||
| 1546 | } break; | ||
| 1547 | case GGML_OP_RMS_NORM: | ||
| 1548 | switch (n_fuse) { | ||
| 1549 | case 1: snprintf(base, 256, "kernel_rms_norm_f32%s", suffix); break; | ||
| 1550 | case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32%s", suffix); break; | ||
| 1551 | case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32%s", suffix); break; | ||
| 1552 | default: GGML_ABORT("fatal error"); | ||
| 1553 | } break; | ||
| 1554 | default: GGML_ABORT("fatal error"); | ||
| 1555 | } | ||
| 1556 | |||
| 1557 | snprintf(name, 256, "%s", base); | ||
| 1558 | |||
| 1559 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1560 | if (!res.pipeline) { | ||
| 1561 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1562 | } | ||
| 1563 | |||
| 1564 | res.smem = 32*sizeof(float); | ||
| 1565 | |||
| 1566 | return res; | ||
| 1567 | } | ||
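// note: n_fuse selects how many ops are folded into the (rms_)norm kernel:
// 1 = plain norm, 2 = norm + mul, 3 = norm + mul + add, matching the kernel name suffixes above.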
| 1568 | |||
| 1569 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1570 | assert(op->op == GGML_OP_ROPE); | ||
| 1571 | |||
| 1572 | char base[256]; | ||
| 1573 | char name[256]; | ||
| 1574 | |||
| 1575 | const int mode = ((const int32_t *) op->op_params)[2]; | ||
| 1576 | |||
| 1577 | const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; | ||
| 1578 | const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; | ||
| 1579 | const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; | ||
| 1580 | const bool is_vision = mode == GGML_ROPE_TYPE_VISION; | ||
| 1581 | |||
| 1582 | if (is_neox) { | ||
| 1583 | snprintf(base, 256, "kernel_rope_neox_%s", ggml_type_name(op->src[0]->type)); | ||
| 1584 | } else if ((is_mrope || is_imrope) && !is_vision) { | ||
| 1585 | GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token | ||
| 1586 | snprintf(base, 256, "kernel_rope_multi_%s", ggml_type_name(op->src[0]->type)); | ||
| 1587 | } else if (is_vision) { | ||
| 1588 | GGML_ASSERT(op->src[1]->ne[0]*4 >= op->src[0]->ne[2]); // need at least 4 pos per token | ||
| 1589 | snprintf(base, 256, "kernel_rope_vision_%s", ggml_type_name(op->src[0]->type)); | ||
| 1590 | } else { | ||
| 1591 | snprintf(base, 256, "kernel_rope_norm_%s", ggml_type_name(op->src[0]->type)); | ||
| 1592 | } | ||
| 1593 | |||
| 1594 | snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0); | ||
| 1595 | |||
| 1596 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1597 | if (!res.pipeline) { | ||
| 1598 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 1599 | |||
| 1600 | ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0); | ||
| 1601 | |||
| 1602 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 1603 | |||
| 1604 | ggml_metal_cv_free(cv); | ||
| 1605 | } | ||
| 1606 | |||
| 1607 | return res; | ||
| 1608 | } | ||
| 1609 | |||
| 1610 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1611 | assert(op->op == GGML_OP_IM2COL); | ||
| 1612 | |||
| 1613 | GGML_ASSERT(ggml_is_contiguous(op->src[1])); | ||
| 1614 | GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); | ||
| 1615 | GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32); | ||
| 1616 | |||
| 1617 | char base[256]; | ||
| 1618 | char name[256]; | ||
| 1619 | |||
| 1620 | snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); | ||
| 1621 | snprintf(name, 256, "%s", base); | ||
| 1622 | |||
| 1623 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1624 | if (!res.pipeline) { | ||
| 1625 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1626 | } | ||
| 1627 | |||
| 1628 | return res; | ||
| 1629 | } | ||
| 1630 | |||
| 1631 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1632 | assert(op->op == GGML_OP_CONV_TRANSPOSE_1D); | ||
| 1633 | |||
| 1634 | GGML_ASSERT(ggml_is_contiguous(op->src[0])); | ||
| 1635 | GGML_ASSERT(ggml_is_contiguous(op->src[1])); | ||
| 1636 | GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); | ||
| 1637 | GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); | ||
| 1638 | GGML_ASSERT(op->type == GGML_TYPE_F32); | ||
| 1639 | |||
| 1640 | char base[256]; | ||
| 1641 | char name[256]; | ||
| 1642 | |||
| 1643 | snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); | ||
| 1644 | snprintf(name, 256, "%s", base); | ||
| 1645 | |||
| 1646 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1647 | if (!res.pipeline) { | ||
| 1648 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1649 | } | ||
| 1650 | |||
| 1651 | return res; | ||
| 1652 | } | ||
| 1653 | |||
| 1654 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1655 | assert(op->op == GGML_OP_CONV_TRANSPOSE_2D); | ||
| 1656 | |||
| 1657 | GGML_ASSERT(ggml_is_contiguous(op->src[0])); | ||
| 1658 | GGML_ASSERT(ggml_is_contiguous(op->src[1])); | ||
| 1659 | GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); | ||
| 1660 | GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); | ||
| 1661 | GGML_ASSERT(op->type == GGML_TYPE_F32); | ||
| 1662 | |||
| 1663 | char base[256]; | ||
| 1664 | char name[256]; | ||
| 1665 | |||
| 1666 | snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); | ||
| 1667 | snprintf(name, 256, "%s", base); | ||
| 1668 | |||
| 1669 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1670 | if (!res.pipeline) { | ||
| 1671 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1672 | } | ||
| 1673 | |||
| 1674 | return res; | ||
| 1675 | } | ||
| 1676 | |||
| 1677 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1678 | assert(op->op == GGML_OP_CONV_2D); | ||
| 1679 | |||
| 1680 | GGML_ASSERT(ggml_is_contiguous(op->src[0])); | ||
| 1681 | GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); | ||
| 1682 | GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); | ||
| 1683 | GGML_ASSERT(op->type == GGML_TYPE_F32); | ||
| 1684 | |||
| 1685 | char base[256]; | ||
| 1686 | char name[256]; | ||
| 1687 | |||
| 1688 | snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); | ||
| 1689 | snprintf(name, 256, "%s", base); | ||
| 1690 | |||
| 1691 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1692 | if (!res.pipeline) { | ||
| 1693 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1694 | } | ||
| 1695 | |||
| 1696 | return res; | ||
| 1697 | } | ||
| 1698 | |||
| 1699 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1700 | assert(op->op == GGML_OP_UPSCALE); | ||
| 1701 | |||
| 1702 | char base[256]; | ||
| 1703 | char name[256]; | ||
| 1704 | |||
| 1705 | snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type)); | ||
| 1706 | snprintf(name, 256, "%s", base); | ||
| 1707 | |||
| 1708 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1709 | if (!res.pipeline) { | ||
| 1710 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1711 | } | ||
| 1712 | |||
| 1713 | return res; | ||
| 1714 | } | ||
| 1715 | |||
| 1716 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1717 | assert(op->op == GGML_OP_PAD); | ||
| 1718 | |||
| 1719 | char base[256]; | ||
| 1720 | char name[256]; | ||
| 1721 | |||
| 1722 | snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type)); | ||
| 1723 | snprintf(name, 256, "%s", base); | ||
| 1724 | |||
| 1725 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1726 | if (res.pipeline) { | ||
| 1727 | return res; | ||
| 1728 | } | ||
| 1729 | |||
| 1730 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1731 | |||
| 1732 | return res; | ||
| 1733 | } | ||
| 1734 | |||
| 1735 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1736 | assert(op->op == GGML_OP_PAD_REFLECT_1D); | ||
| 1737 | |||
| 1738 | char base[256]; | ||
| 1739 | char name[256]; | ||
| 1740 | |||
| 1741 | snprintf(base, 256, "kernel_pad_reflect_1d_%s", ggml_type_name(op->src[0]->type)); | ||
| 1742 | snprintf(name, 256, "%s", base); | ||
| 1743 | |||
| 1744 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1745 | if (!res.pipeline) { | ||
| 1746 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1747 | } | ||
| 1748 | |||
| 1749 | return res; | ||
| 1750 | } | ||
| 1751 | |||
| 1752 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1753 | assert(op->op == GGML_OP_ARANGE); | ||
| 1754 | |||
| 1755 | char base[256]; | ||
| 1756 | char name[256]; | ||
| 1757 | |||
| 1758 | snprintf(base, 256, "kernel_arange_%s", ggml_type_name(op->type)); | ||
| 1759 | snprintf(name, 256, "%s", base); | ||
| 1760 | |||
| 1761 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1762 | if (!res.pipeline) { | ||
| 1763 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1764 | } | ||
| 1765 | |||
| 1766 | return res; | ||
| 1767 | } | ||
| 1768 | |||
| 1769 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1770 | assert(op->op == GGML_OP_TIMESTEP_EMBEDDING); | ||
| 1771 | |||
| 1772 | char base[256]; | ||
| 1773 | char name[256]; | ||
| 1774 | |||
| 1775 | snprintf(base, 256, "kernel_timestep_embedding_%s", ggml_type_name(op->src[0]->type)); | ||
| 1776 | snprintf(name, 256, "%s", base); | ||
| 1777 | |||
| 1778 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1779 | if (!res.pipeline) { | ||
| 1780 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1781 | } | ||
| 1782 | |||
| 1783 | return res; | ||
| 1784 | } | ||
| 1785 | |||
| 1786 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1787 | assert(op->op == GGML_OP_OPT_STEP_ADAMW); | ||
| 1788 | |||
| 1789 | char base[256]; | ||
| 1790 | char name[256]; | ||
| 1791 | |||
| 1792 | snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type)); | ||
| 1793 | snprintf(name, 256, "%s", base); | ||
| 1794 | |||
| 1795 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1796 | if (!res.pipeline) { | ||
| 1797 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1798 | } | ||
| 1799 | |||
| 1800 | return res; | ||
| 1801 | } | ||
| 1802 | |||
| 1803 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1804 | assert(op->op == GGML_OP_OPT_STEP_SGD); | ||
| 1805 | |||
| 1806 | char base[256]; | ||
| 1807 | char name[256]; | ||
| 1808 | |||
| 1809 | snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type)); | ||
| 1810 | snprintf(name, 256, "%s", base); | ||
| 1811 | |||
| 1812 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1813 | if (!res.pipeline) { | ||
| 1814 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1815 | } | ||
| 1816 | |||
| 1817 | return res; | ||
| 1818 | } | ||
| 1819 | |||
| 1820 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1821 | GGML_ASSERT(op->type == GGML_TYPE_I64); | ||
| 1822 | |||
| 1823 | char base[256]; | ||
| 1824 | char name[256]; | ||
| 1825 | |||
| 1826 | snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type)); | ||
| 1827 | snprintf(name, 256, "%s", base); | ||
| 1828 | |||
| 1829 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1830 | if (!res.pipeline) { | ||
| 1831 | res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||
| 1832 | } | ||
| 1833 | |||
| 1834 | return res; | ||
| 1835 | } | ||
| 1836 | |||
| 1837 | ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor * op) { | ||
| 1838 | assert(op->op == GGML_OP_COUNT_EQUAL); | ||
| 1839 | |||
| 1840 | GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne); | ||
| 1841 | |||
| 1842 | GGML_ASSERT(op->src[0]->type == op->src[1]->type); | ||
| 1843 | GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32); | ||
| 1844 | GGML_ASSERT(op->type == GGML_TYPE_I64); | ||
| 1845 | |||
| 1846 | // note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int | ||
| 1847 | GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31)); | ||
| 1848 | |||
| 1849 | char base[256]; | ||
| 1850 | char name[256]; | ||
| 1851 | |||
| 1852 | int nsg = 1; | ||
| 1853 | while (32*nsg < ne00 && nsg < 32) { | ||
| 1854 | nsg *= 2; | ||
| 1855 | } | ||
| 1856 | |||
| 1857 | snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type)); | ||
| 1858 | snprintf(name, 256, "%s_nsg=%d", base, nsg); | ||
| 1859 | |||
| 1860 | ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); | ||
| 1861 | if (!res.pipeline) { | ||
| 1862 | ggml_metal_cv_t cv = ggml_metal_cv_init(); | ||
| 1863 | |||
| 1864 | ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0); | ||
| 1865 | |||
| 1866 | res = ggml_metal_library_compile_pipeline(lib, base, name, cv); | ||
| 1867 | |||
| 1868 | ggml_metal_cv_free(cv); | ||
| 1869 | } | ||
| 1870 | |||
| 1871 | res.smem = 32 * sizeof(int32_t); | ||
| 1872 | res.nsg = nsg; | ||
| 1873 | |||
| 1874 | return res; | ||
| 1875 | } | ||
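// example: ne00 = 100 gives nsg = 4 (1 -> 2 -> 4, since 32*4 = 128 >= 100), while long rows
// saturate at the cap of nsg = 32; threadgroup memory is fixed at 32*sizeof(int32_t) = 128 bytes.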
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.h b/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.h new file mode 100644 index 0000000..93d7f6a --- /dev/null +++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.h | |||
| @@ -0,0 +1,290 @@ | |||
| 1 | #pragma once | ||
| 2 | |||
| 3 | #include "ggml.h" | ||
| 4 | |||
| 5 | #ifdef __cplusplus | ||
| 6 | extern "C" { | ||
| 7 | #endif | ||
| 8 | |||
| 9 | struct ggml_metal_buffer_id { | ||
| 10 | void * metal; // id<MTLBuffer> | ||
| 11 | size_t offs; | ||
| 12 | }; | ||
| 13 | |||
| 14 | typedef struct ggml_metal_device * ggml_metal_device_t; | ||
| 15 | |||
| 16 | // | ||
| 17 | // MTLFunctionConstantValues wrapper | ||
| 18 | // | ||
| 19 | |||
| 20 | typedef struct ggml_metal_cv * ggml_metal_cv_t; | ||
| 21 | |||
| 22 | ggml_metal_cv_t ggml_metal_cv_init(void); | ||
| 23 | void ggml_metal_cv_free(ggml_metal_cv_t cv); | ||
| 24 | |||
| 25 | void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx); | ||
| 26 | void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx); | ||
| 27 | void ggml_metal_cv_set_bool (ggml_metal_cv_t cv, bool value, int32_t idx); | ||
| 28 | |||
| 29 | // | ||
| 30 | // MTLComputePipelineState wrapper | ||
| 31 | // | ||
| 32 | |||
| 33 | typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t; | ||
| 34 | |||
| 35 | ggml_metal_pipeline_t ggml_metal_pipeline_init(void); | ||
| 36 | void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline); | ||
| 37 | |||
| 38 | // a collection of pipelines | ||
| 39 | typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t; | ||
| 40 | |||
| 41 | ggml_metal_pipelines_t ggml_metal_pipelines_init(void); | ||
| 42 | void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls); | ||
| 43 | |||
| 44 | void ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline); | ||
| 45 | ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name); | ||
| 46 | |||
| 47 | struct ggml_metal_pipeline_with_params { | ||
| 48 | ggml_metal_pipeline_t pipeline; | ||
| 49 | |||
| 50 | int nsg; | ||
| 51 | |||
| 52 | int nr0; | ||
| 53 | int nr1; | ||
| 54 | |||
| 55 | size_t smem; | ||
| 56 | |||
| 57 | bool c4; | ||
| 58 | bool cnt; | ||
| 59 | }; | ||
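// the fields besides .pipeline are per-kernel hints filled in by the pipeline getters and read
// back by the dispatch code: nsg is the number of simdgroups, smem the required threadgroup
// memory in bytes, c4 selects the vectorized "_4" variant, cnt is a kernel-specific flag
// (e.g. the contiguous fast path of the binary ops); nr0/nr1 are presumably rows-per-threadgroup
// hints for the matrix kernels (they are not set in the functions shown here).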
| 60 | |||
| 61 | int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline); | ||
| 62 | |||
| 63 | // | ||
| 64 | // MTLCommandBuffer wrapper | ||
| 65 | // | ||
| 66 | |||
| 67 | typedef void * ggml_metal_cmd_buf_t; | ||
| 68 | |||
| 69 | // | ||
| 70 | // MTLComputeCommandEncoder wrapper | ||
| 71 | // | ||
| 72 | |||
| 73 | typedef struct ggml_metal_encoder * ggml_metal_encoder_t; | ||
| 74 | |||
| 75 | ggml_metal_encoder_t ggml_metal_encoder_init(ggml_metal_cmd_buf_t cmd_buf_raw, bool concurrent); | ||
| 76 | void ggml_metal_encoder_free(ggml_metal_encoder_t encoder); | ||
| 77 | |||
| 78 | void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name); | ||
| 79 | void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder); | ||
| 80 | |||
| 81 | void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline); | ||
| 82 | |||
| 83 | void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx); | ||
| 84 | void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx); | ||
| 85 | |||
| 86 | void ggml_metal_encoder_set_threadgroup_memory_size(ggml_metal_encoder_t encoder, size_t size, int idx); | ||
| 87 | |||
| 88 | void ggml_metal_encoder_dispatch_threadgroups(ggml_metal_encoder_t encoder, int tg0, int tg1, int tg2, int tptg0, int tptg1, int tptg2); | ||
| 89 | |||
| 90 | void ggml_metal_encoder_memory_barrier(ggml_metal_encoder_t encoder); | ||
| 91 | |||
| 92 | void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder); | ||
| 93 | |||
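// usage sketch (illustrative, not part of this header): a single kernel dispatch is typically
// assembled from the wrappers above; the buffer indices and threadgroup sizes below are assumptions.
//
//   ggml_metal_encoder_t enc = ggml_metal_encoder_init(cmd_buf, /*concurrent =*/ true);
//
//   ggml_metal_encoder_set_pipeline(enc, ppl);                    // ppl comes from one of the getters below
//   ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0); // hypothetical kernel argument struct
//   ggml_metal_encoder_set_buffer  (enc, src0_id, 1);
//   ggml_metal_encoder_set_buffer  (enc, dst_id,  2);
//
//   if (ppl.smem > 0) {
//       ggml_metal_encoder_set_threadgroup_memory_size(enc, ppl.smem, 0);
//   }
//
//   ggml_metal_encoder_dispatch_threadgroups(enc, ntg0, ntg1, ntg2, 32*ppl.nsg, 1, 1);
//   ggml_metal_encoder_end_encoding(enc);
//   ggml_metal_encoder_free(enc);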
| 94 | // | ||
| 95 | // MTLLibrary wrapper | ||
| 96 | // | ||
| 97 | |||
| 98 | typedef struct ggml_metal_library * ggml_metal_library_t; | ||
| 99 | |||
| 100 | ggml_metal_library_t ggml_metal_library_init (ggml_metal_device_t dev); | ||
| 101 | ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose); | ||
| 102 | |||
| 103 | void ggml_metal_library_free(ggml_metal_library_t lib); | ||
| 104 | |||
| 105 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); | ||
| 106 | struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv); | ||
| 107 | |||
| 108 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op); | ||
| 109 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst); | ||
| 110 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); | ||
| 111 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); | ||
| 112 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc); | ||
| 113 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); | ||
| 114 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 115 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); | ||
| 116 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 117 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 118 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 119 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 120 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 121 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 122 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 123 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 124 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 125 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs); | ||
| 126 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 127 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 128 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 129 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); | ||
| 130 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 131 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 132 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); | ||
| 133 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 134 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 135 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 136 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 137 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 138 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 139 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 140 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse ); | ||
| 141 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one (ggml_metal_library_t lib, enum ggml_op op); | ||
| 142 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 143 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 144 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); | ||
| 145 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 146 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 147 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 148 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 149 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 150 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 151 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 152 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 153 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 154 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 155 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 156 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 157 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 158 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||
| 159 | |||
| 160 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad( | ||
| 161 | ggml_metal_library_t lib, | ||
| 162 | const struct ggml_tensor * op, | ||
| 163 | bool has_mask, | ||
| 164 | int32_t ncpsg); | ||
| 165 | |||
| 166 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk( | ||
| 167 | ggml_metal_library_t lib, | ||
| 168 | const struct ggml_tensor * op, | ||
| 169 | int32_t nqptg, | ||
| 170 | int32_t ncpsg); | ||
| 171 | |||
| 172 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext( | ||
| 173 | ggml_metal_library_t lib, | ||
| 174 | const struct ggml_tensor * op, | ||
| 175 | bool has_mask, | ||
| 176 | bool has_sinks, | ||
| 177 | bool has_bias, | ||
| 178 | bool has_scap, | ||
| 179 | bool has_kvpad, | ||
| 180 | int32_t nsg); | ||
| 181 | |||
| 182 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec( | ||
| 183 | ggml_metal_library_t lib, | ||
| 184 | const struct ggml_tensor * op, | ||
| 185 | bool has_mask, | ||
| 186 | bool has_sinks, | ||
| 187 | bool has_bias, | ||
| 188 | bool has_scap, | ||
| 189 | bool has_kvpad, | ||
| 190 | int32_t nsg, | ||
| 191 | int32_t nwg); | ||
| 192 | |||
| 193 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce( | ||
| 194 | ggml_metal_library_t lib, | ||
| 195 | const struct ggml_tensor * op, | ||
| 196 | int32_t dv, | ||
| 197 | int32_t nwg); | ||
| 198 | |||
| 199 | // MTLResidencySet wrapper | ||
| 200 | |||
| 201 | typedef void * ggml_metal_rset_t; | ||
| 202 | |||
| 203 | // a collection of residency sets (non-owning) | ||
| 204 | typedef struct ggml_metal_rsets * ggml_metal_rsets_t; | ||
| 205 | |||
| 206 | ggml_metal_rsets_t ggml_metal_rsets_init(void); | ||
| 207 | void ggml_metal_rsets_free(ggml_metal_rsets_t rsets); | ||
| 208 | |||
| 209 | // | ||
| 210 | // device | ||
| 211 | // | ||
| 212 | |||
| 213 | struct ggml_metal_device_props { | ||
| 214 | int device; | ||
| 215 | char name[128]; | ||
| 216 | char desc[128]; | ||
| 217 | |||
| 218 | size_t max_buffer_size; | ||
| 219 | size_t max_working_set_size; | ||
| 220 | size_t max_theadgroup_memory_size; | ||
| 221 | |||
| 222 | bool has_simdgroup_reduction; | ||
| 223 | bool has_simdgroup_mm; | ||
| 224 | bool has_unified_memory; | ||
| 225 | bool has_bfloat; | ||
| 226 | bool has_tensor; | ||
| 227 | bool use_residency_sets; | ||
| 228 | bool use_shared_buffers; | ||
| 229 | |||
| 230 | bool supports_gpu_family_apple7; | ||
| 231 | |||
| 232 | int op_offload_min_batch_size; | ||
| 233 | }; | ||
| 234 | |||
| 235 | typedef struct ggml_metal_event * ggml_metal_event_t; | ||
| 236 | |||
| 237 | void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf); | ||
| 238 | void ggml_metal_event_encode_wait (ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf); | ||
| 239 | |||
| 240 | ggml_metal_device_t ggml_metal_device_init(int device); | ||
| 241 | void ggml_metal_device_free(ggml_metal_device_t dev); | ||
| 242 | |||
| 243 | ggml_metal_device_t ggml_metal_device_get(int device); | ||
| 244 | |||
| 245 | void * ggml_metal_device_get_obj (ggml_metal_device_t dev); // id<MTLDevice> | ||
| 246 | void * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id<MTLCommandQueue> | ||
| 247 | |||
| 248 | ggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev); | ||
| 249 | |||
| 250 | void ggml_metal_device_rsets_add(ggml_metal_device_t dev, ggml_metal_rset_t rset); | ||
| 251 | void ggml_metal_device_rsets_rm (ggml_metal_device_t dev, ggml_metal_rset_t rset); | ||
| 252 | |||
| 253 | void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev); | ||
| 254 | |||
| 255 | ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev); | ||
| 256 | void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev); | ||
| 257 | void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev); | ||
| 258 | |||
| 259 | void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total); | ||
| 260 | bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op); | ||
| 261 | |||
| 262 | const struct ggml_metal_device_props * ggml_metal_device_get_props(ggml_metal_device_t dev); | ||
| 263 | |||
| 264 | // | ||
| 265 | // device buffers | ||
| 266 | // | ||
| 267 | |||
| 268 | typedef struct ggml_metal_buffer * ggml_metal_buffer_t; | ||
| 269 | |||
| 270 | ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared); | ||
| 271 | ggml_metal_buffer_t ggml_metal_buffer_map (ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size); | ||
| 272 | |||
| 273 | void ggml_metal_buffer_free (ggml_metal_buffer_t buf); | ||
| 274 | void * ggml_metal_buffer_get_base (ggml_metal_buffer_t buf); | ||
| 275 | bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf); | ||
| 276 | |||
| 277 | void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); | ||
| 278 | void ggml_metal_buffer_set_tensor (ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); | ||
| 279 | void ggml_metal_buffer_get_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); | ||
| 280 | void ggml_metal_buffer_clear (ggml_metal_buffer_t buf, uint8_t value); | ||
| 281 | |||
| 282 | // finds the Metal buffer that contains the tensor data on the GPU device | ||
| 283 | // the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the | ||
| 284 | // Metal buffer based on the host memory pointer | ||
| 285 | // | ||
| 286 | struct ggml_metal_buffer_id ggml_metal_buffer_get_id(ggml_metal_buffer_t buf, const struct ggml_tensor * t); | ||
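// usage sketch (illustrative, not part of this header): allocating a device buffer and binding a tensor
//
//   ggml_metal_device_t dev = ggml_metal_device_get(0);
//   ggml_metal_buffer_t buf = ggml_metal_buffer_init(dev, size, /*shared =*/ true);
//
//   ggml_metal_buffer_set_tensor(buf, tensor, data, 0, ggml_nbytes(tensor));
//
//   struct ggml_metal_buffer_id bid = ggml_metal_buffer_get_id(buf, tensor); // bound via ggml_metal_encoder_set_buffer()
//
// whether a shared buffer is appropriate depends on the device (see use_shared_buffers in ggml_metal_device_props).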
| 287 | |||
| 288 | #ifdef __cplusplus | ||
| 289 | } | ||
| 290 | #endif | ||
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m b/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m new file mode 100644 index 0000000..4ea0bfb --- /dev/null +++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-device.m | |||
| @@ -0,0 +1,1748 @@ | |||
| 1 | #import "ggml-metal-device.h" | ||
| 2 | |||
| 3 | #import "ggml-impl.h" | ||
| 4 | |||
| 5 | #include <Foundation/Foundation.h> | ||
| 6 | |||
| 7 | #include <Metal/Metal.h> | ||
| 8 | |||
| 9 | #include <stdatomic.h> | ||
| 10 | |||
| 11 | #ifndef TARGET_OS_VISION | ||
| 12 | #define TARGET_OS_VISION 0 | ||
| 13 | #endif | ||
| 14 | |||
| 15 | // create residency sets only on macOS >= 15.0 (and the corresponding iOS / tvOS / visionOS releases) | ||
| 16 | #if !TARGET_CPU_X86_64 && TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \ | ||
| 17 | TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \ | ||
| 18 | TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \ | ||
| 19 | TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000 | ||
| 20 | #define GGML_METAL_HAS_RESIDENCY_SETS 1 | ||
| 21 | #endif | ||
| 22 | |||
| 23 | // fallback definitions of MTLGPUFamilyMetal3/Metal4 (not available in some environments) | ||
| 24 | static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; | ||
| 25 | static const NSInteger MTLGPUFamilyMetal4_GGML = 5002; | ||
| 26 | |||
| 27 | #if !GGML_METAL_EMBED_LIBRARY | ||
| 28 | // Here to assist with NSBundle Path Hack | ||
| 29 | @interface GGMLMetalClass : NSObject | ||
| 30 | @end | ||
| 31 | @implementation GGMLMetalClass | ||
| 32 | @end | ||
| 33 | #endif | ||
| 34 | |||
| 35 | // | ||
| 36 | // MTLFunctionConstantValues wrapper | ||
| 37 | // | ||
| 38 | |||
| 39 | struct ggml_metal_cv { | ||
| 40 | MTLFunctionConstantValues * obj; | ||
| 41 | }; | ||
| 42 | |||
| 43 | ggml_metal_cv_t ggml_metal_cv_init(void) { | ||
| 44 | ggml_metal_cv_t res = calloc(1, sizeof(struct ggml_metal_cv)); | ||
| 45 | |||
| 46 | res->obj = [[MTLFunctionConstantValues alloc] init]; | ||
| 47 | |||
| 48 | return res; | ||
| 49 | } | ||
| 50 | |||
| 51 | void ggml_metal_cv_free(ggml_metal_cv_t cv) { | ||
| 52 | [cv->obj release]; | ||
| 53 | free(cv); | ||
| 54 | } | ||
| 55 | |||
| 56 | void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx) { | ||
| 57 | [cv->obj setConstantValue:&value type:MTLDataTypeShort atIndex:idx]; | ||
| 58 | } | ||
| 59 | |||
| 60 | void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx) { | ||
| 61 | [cv->obj setConstantValue:&value type:MTLDataTypeInt atIndex:idx]; | ||
| 62 | } | ||
| 63 | |||
| 64 | void ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) { | ||
| 65 | [cv->obj setConstantValue:&value type:MTLDataTypeBool atIndex:idx]; | ||
| 66 | } | ||
| 67 | |||
| 68 | // | ||
| 69 | // MTLComputePipelineState wrapper | ||
| 70 | // | ||
| 71 | |||
| 72 | struct ggml_metal_pipeline { | ||
| 73 | id<MTLComputePipelineState> obj; | ||
| 74 | }; | ||
| 75 | |||
| 76 | ggml_metal_pipeline_t ggml_metal_pipeline_init(void) { | ||
| 77 | ggml_metal_pipeline_t res = calloc(1, sizeof(struct ggml_metal_pipeline)); | ||
| 78 | |||
| 79 | *res = (struct ggml_metal_pipeline) { | ||
| 80 | /*.obj =*/ nil, | ||
| 81 | }; | ||
| 82 | |||
| 83 | return res; | ||
| 84 | } | ||
| 85 | |||
| 86 | void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) { | ||
| 87 | [pipeline->obj release]; | ||
| 88 | |||
| 89 | free(pipeline); | ||
| 90 | } | ||
| 91 | |||
| 92 | int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline) { | ||
| 93 | return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup; | ||
| 94 | } | ||
| 95 | |||
| 96 | struct ggml_metal_library { | ||
| 97 | id<MTLLibrary> obj; | ||
| 98 | id<MTLDevice> device; | ||
| 99 | |||
| 100 | ggml_metal_pipelines_t pipelines; // cache of compiled pipelines | ||
| 101 | |||
| 102 | NSLock * lock; | ||
| 103 | }; | ||
| 104 | |||
| 105 | ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { | ||
| 106 | id<MTLLibrary> library = nil; | ||
| 107 | id<MTLDevice> device = ggml_metal_device_get_obj(dev); | ||
| 108 | |||
| 109 | // load library | ||
| 110 | // | ||
| 111 | // - first check if the library is embedded | ||
| 112 | // - then check if the library is in the bundle | ||
| 113 | // - if not found, load the source and compile it | ||
| 114 | // - if that fails, return NULL | ||
| 115 | // | ||
| 116 | // TODO: move to a function | ||
| 117 | { | ||
| 118 | const int64_t t_start = ggml_time_us(); | ||
| 119 | |||
| 120 | NSError * error = nil; | ||
| 121 | NSString * src = nil; | ||
| 122 | |||
| 123 | #if GGML_METAL_EMBED_LIBRARY | ||
| 124 | GGML_LOG_INFO("%s: using embedded metal library\n", __func__); | ||
| 125 | |||
| 126 | extern const char ggml_metallib_start[]; | ||
| 127 | extern const char ggml_metallib_end[]; | ||
| 128 | |||
| 129 | src = [[NSString alloc] initWithBytes:ggml_metallib_start length:(ggml_metallib_end-ggml_metallib_start) encoding:NSUTF8StringEncoding]; | ||
| 130 | #else | ||
| 131 | |||
| 132 | #ifdef SWIFT_PACKAGE | ||
| 133 | NSBundle * bundle = SWIFTPM_MODULE_BUNDLE; | ||
| 134 | #else | ||
| 135 | NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; | ||
| 136 | #endif | ||
| 137 | |||
| 138 | NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"]; | ||
| 139 | if (path_lib == nil) { | ||
| 140 | // Try to find the resource in the directory where the current binary is located. | ||
| 141 | NSString * bin_cur = [[NSProcessInfo processInfo] arguments][0]; | ||
| 142 | NSString * bin_dir = [bin_cur stringByDeletingLastPathComponent]; | ||
| 143 | |||
| 144 | NSString * path_lib_default = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]]; | ||
| 145 | if ([[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) { | ||
| 146 | GGML_LOG_INFO("%s: found '%s'\n", __func__, [path_lib_default UTF8String]); | ||
| 147 | |||
| 148 | NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:path_lib_default error:&error]; | ||
| 149 | if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) { | ||
| 150 | // Optionally, if this is a symlink, try to resolve it. | ||
| 151 | path_lib_default = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:path_lib_default error:&error]; | ||
| 152 | if (path_lib_default && [path_lib_default length] > 0 && ![[path_lib_default substringToIndex:1] isEqualToString:@"/"]) { | ||
| 153 | // It is a relative path, so prepend the binary directory. | ||
| 154 | path_lib_default = [NSString pathWithComponents:@[bin_dir, path_lib_default]]; | ||
| 155 | } | ||
| 156 | if (!path_lib_default || ![[NSFileManager defaultManager] isReadableFileAtPath:path_lib_default]) { | ||
| 157 | // The link to the resource could not be resolved. | ||
| 158 | path_lib_default = nil; | ||
| 159 | } else { | ||
| 160 | GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [path_lib_default UTF8String]); | ||
| 161 | } | ||
| 162 | } | ||
| 163 | } else { | ||
| 164 | // The resource couldn't be found in the binary's directory. | ||
| 165 | path_lib_default = nil; | ||
| 166 | } | ||
| 167 | |||
| 168 | path_lib = path_lib_default; | ||
| 169 | } | ||
| 170 | |||
| 171 | if (path_lib != nil) { | ||
| 172 | // pre-compiled library found | ||
| 173 | NSURL * libURL = [NSURL fileURLWithPath:path_lib]; | ||
| 174 | GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]); | ||
| 175 | |||
| 176 | library = [device newLibraryWithURL:libURL error:&error]; | ||
| 177 | if (error) { | ||
| 178 | GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); | ||
| 179 | return nil; | ||
| 180 | } | ||
| 181 | } else { | ||
| 182 | GGML_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); | ||
| 183 | |||
| 184 | NSString * path_source; | ||
| 185 | NSString * path_resource = [[NSProcessInfo processInfo].environment objectForKey:@"GGML_METAL_PATH_RESOURCES"]; | ||
| 186 | |||
| 187 | GGML_LOG_INFO("%s: GGML_METAL_PATH_RESOURCES = %s\n", __func__, path_resource ? [path_resource UTF8String] : "nil"); | ||
| 188 | |||
| 189 | if (path_resource) { | ||
| 190 | path_source = [path_resource stringByAppendingPathComponent:@"ggml-metal.metal"]; | ||
| 191 | } else { | ||
| 192 | path_source = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; | ||
| 193 | } | ||
| 194 | |||
| 195 | if (path_source == nil) { | ||
| 196 | GGML_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__); | ||
| 197 | path_source = @"ggml-metal.metal"; | ||
| 198 | } | ||
| 199 | |||
| 200 | GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_source UTF8String]); | ||
| 201 | |||
| 202 | src = [NSString stringWithContentsOfFile:path_source encoding:NSUTF8StringEncoding error:&error]; | ||
| 203 | if (error) { | ||
| 204 | GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); | ||
| 205 | return nil; | ||
| 206 | } | ||
| 207 | } | ||
| 208 | #endif | ||
| 209 | |||
| 210 | if (!library) { | ||
| 211 | @autoreleasepool { | ||
| 212 | // dictionary of preprocessor macros | ||
| 213 | NSMutableDictionary * prep = [NSMutableDictionary dictionary]; | ||
| 214 | |||
| 215 | if (ggml_metal_device_get_props(dev)->has_bfloat) { | ||
| 216 | [prep setObject:@"1" forKey:@"GGML_METAL_HAS_BF16"]; | ||
| 217 | } | ||
| 218 | |||
| 219 | if (ggml_metal_device_get_props(dev)->has_tensor) { | ||
| 220 | [prep setObject:@"1" forKey:@"GGML_METAL_HAS_TENSOR"]; | ||
| 221 | } | ||
| 222 | |||
| 223 | #if GGML_METAL_EMBED_LIBRARY | ||
| 224 | [prep setObject:@"1" forKey:@"GGML_METAL_EMBED_LIBRARY"]; | ||
| 225 | #endif | ||
| 226 | |||
| 227 | MTLCompileOptions * options = [MTLCompileOptions new]; | ||
| 228 | options.preprocessorMacros = prep; | ||
| 229 | |||
| 230 | //[options setFastMathEnabled:false]; | ||
| 231 | |||
| 232 | library = [device newLibraryWithSource:src options:options error:&error]; | ||
| 233 | if (error) { | ||
| 234 | GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); | ||
| 235 | return nil; | ||
| 236 | } | ||
| 237 | |||
| 238 | #if !__has_feature(objc_arc) | ||
| 239 | [options release]; | ||
| 240 | #endif | ||
| 241 | } | ||
| 242 | } | ||
| 243 | |||
| 244 | #if GGML_METAL_EMBED_LIBRARY | ||
| 245 | [src release]; | ||
| 246 | #endif // GGML_METAL_EMBED_LIBRARY | ||
| 247 | |||
| 248 | GGML_LOG_INFO("%s: loaded in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6); | ||
| 249 | } | ||
| 250 | |||
| 251 | ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library)); | ||
| 252 | |||
| 253 | res->obj = library; | ||
| 254 | res->device = device; | ||
| 255 | res->pipelines = ggml_metal_pipelines_init(); | ||
| 256 | res->lock = [NSLock new]; | ||
| 257 | |||
| 258 | return res; | ||
| 259 | } | ||
| 260 | |||
| 261 | ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose) { | ||
| 262 | if (source == NULL) { | ||
| 263 | GGML_LOG_ERROR("%s: source is NULL\n", __func__); | ||
| 264 | return NULL; | ||
| 265 | } | ||
| 266 | |||
| 267 | id<MTLDevice> device = ggml_metal_device_get_obj(dev); | ||
| 268 | id<MTLLibrary> library = nil; | ||
| 269 | NSError * error = nil; | ||
| 270 | |||
| 271 | const int64_t t_start = ggml_time_us(); | ||
| 272 | |||
| 273 | NSString * src = [[NSString alloc] initWithBytes:source | ||
| 274 | length:strlen(source) | ||
| 275 | encoding:NSUTF8StringEncoding]; | ||
| 276 | if (!src) { | ||
| 277 | GGML_LOG_ERROR("%s: failed to create NSString from source\n", __func__); | ||
| 278 | return NULL; | ||
| 279 | } | ||
| 280 | |||
| 281 | @autoreleasepool { | ||
| 282 | NSMutableDictionary * prep = [NSMutableDictionary dictionary]; | ||
| 283 | |||
| 284 | MTLCompileOptions * options = [MTLCompileOptions new]; | ||
| 285 | options.preprocessorMacros = prep; | ||
| 286 | |||
| 287 | library = [device newLibraryWithSource:src options:options error:&error]; | ||
| 288 | if (error) { | ||
| 289 | if (verbose) { | ||
| 290 | GGML_LOG_ERROR("%s: error compiling source: %s\n", __func__, [[error description] UTF8String]); | ||
| 291 | } else { | ||
| 292 | GGML_LOG_ERROR("%s: error compiling source\n", __func__); | ||
| 293 | } | ||
| 294 | library = nil; | ||
| 295 | } | ||
| 296 | |||
| 297 | [options release]; | ||
| 298 | } | ||
| 299 | |||
| 300 | [src release]; | ||
| 301 | |||
| 302 | if (!library) { | ||
| 303 | if (verbose) { | ||
| 304 | GGML_LOG_ERROR("%s: failed to create Metal library from source\n", __func__); | ||
| 305 | } | ||
| 306 | |||
| 307 | return NULL; | ||
| 308 | } | ||
| 309 | |||
| 310 | if (verbose) { | ||
| 311 | GGML_LOG_INFO("%s: compiled in %.3f sec\n", __func__, (ggml_time_us() - t_start) / 1e6); | ||
| 312 | } | ||
| 313 | |||
| 314 | ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library)); | ||
| 315 | if (!res) { | ||
| 316 | GGML_LOG_ERROR("%s: calloc failed\n", __func__); | ||
| 317 | return NULL; | ||
| 318 | } | ||
| 319 | |||
| 320 | res->obj = library; | ||
| 321 | res->device = device; | ||
| 322 | res->pipelines = ggml_metal_pipelines_init(); | ||
| 323 | res->lock = [NSLock new]; | ||
| 324 | |||
| 325 | return res; | ||
| 326 | } | ||
| 327 | |||
| 328 | void ggml_metal_library_free(ggml_metal_library_t lib) { | ||
| 329 | if (!lib) { | ||
| 330 | return; | ||
| 331 | } | ||
| 332 | |||
| 333 | if (lib->obj) { | ||
| 334 | [lib->obj release]; | ||
| 335 | } | ||
| 336 | |||
| 337 | ggml_metal_pipelines_free(lib->pipelines); | ||
| 338 | |||
| 339 | [lib->lock release]; | ||
| 340 | |||
| 341 | free(lib); | ||
| 342 | } | ||
| 343 | |||
| 344 | struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { | ||
| 345 | [lib->lock lock]; | ||
| 346 | |||
| 347 | struct ggml_metal_pipeline_with_params res = { | ||
| 348 | /*.pipeline =*/ nil, | ||
| 349 | /*.nsg =*/ 0, | ||
| 350 | /*.nr0 =*/ 0, | ||
| 351 | /*.nr1 =*/ 0, | ||
| 352 | /*.smem =*/ 0, | ||
| 353 | /*.c4 =*/ false, | ||
| 354 | /*.cnt =*/ false, | ||
| 355 | }; | ||
| 356 | |||
| 357 | res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); | ||
| 358 | |||
| 359 | [lib->lock unlock]; | ||
| 360 | |||
| 361 | return res; | ||
| 362 | } | ||
| 363 | |||
| 364 | struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { | ||
| 365 | struct ggml_metal_pipeline_with_params res = { | ||
| 366 | /*.pipeline =*/ nil, | ||
| 367 | /*.nsg =*/ 0, | ||
| 368 | /*.nr0 =*/ 0, | ||
| 369 | /*.nr1 =*/ 0, | ||
| 370 | /*.smem =*/ 0, | ||
| 371 | /*.c4 =*/ false, | ||
| 372 | /*.cnt =*/ false, | ||
| 373 | }; | ||
| 374 | |||
| 375 | [lib->lock lock]; | ||
| 376 | |||
| 377 | res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); | ||
| 378 | if (res.pipeline) { | ||
| 379 | [lib->lock unlock]; | ||
| 380 | |||
| 381 | return res; | ||
| 382 | } | ||
| 383 | |||
| 384 | @autoreleasepool { | ||
| 385 | NSError * error = nil; | ||
| 386 | |||
| 387 | NSString * base_func = [NSString stringWithUTF8String:base]; | ||
| 388 | |||
| 389 | GGML_LOG_DEBUG("%s: compiling pipeline: base = '%s', name = '%s'\n", __func__, base, name); | ||
| 390 | |||
| 391 | id<MTLFunction> mtl_function; | ||
| 392 | if (!cv) { | ||
| 393 | mtl_function = [lib->obj newFunctionWithName:base_func]; | ||
| 394 | } else { | ||
| 395 | mtl_function = [lib->obj newFunctionWithName:base_func constantValues:cv->obj error:&error]; | ||
| 396 | } | ||
| 397 | if (!mtl_function) { | ||
| 398 | [lib->lock unlock]; | ||
| 399 | |||
| 400 | GGML_LOG_ERROR("%s: failed to compile pipeline: base = '%s', name = '%s'\n", __func__, base, name); | ||
| 401 | if (error) { | ||
| 402 | GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]); | ||
| 403 | } | ||
| 404 | |||
| 405 | return res; | ||
| 406 | } | ||
| 407 | |||
| 408 | id<MTLComputePipelineState> obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; | ||
| 409 | |||
| 410 | [mtl_function release]; | ||
| 411 | |||
| 412 | if (!obj) { | ||
| 413 | [lib->lock unlock]; | ||
| 414 | |||
| 415 | GGML_LOG_ERROR("%s: failed to create pipeline state: base = '%s', name = '%s'\n", __func__, base, name); | ||
| 416 | if (error) { | ||
| 417 | GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]); | ||
| 418 | } | ||
| 419 | |||
| 420 | return res; | ||
| 421 | } | ||
| 422 | |||
| 423 | GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, | ||
| 424 | (void *) obj, | ||
| 425 | (int) obj.maxTotalThreadsPerThreadgroup, | ||
| 426 | (int) obj.threadExecutionWidth); | ||
| 427 | |||
| 428 | if (obj.maxTotalThreadsPerThreadgroup == 0 || obj.threadExecutionWidth == 0) { | ||
| 429 | [obj release]; | ||
| 430 | |||
| 431 | [lib->lock unlock]; | ||
| 432 | |||
| 433 | GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name); | ||
| 434 | |||
| 435 | return res; | ||
| 436 | } | ||
| 437 | |||
| 438 | res.pipeline = ggml_metal_pipeline_init(); | ||
| 439 | res.pipeline->obj = obj; | ||
| 440 | |||
| 441 | ggml_metal_pipelines_add(lib->pipelines, name, res.pipeline); | ||
| 442 | } | ||
| 443 | |||
| 444 | [lib->lock unlock]; | ||
| 445 | |||
| 446 | return res; | ||
| 447 | } | ||
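// note: lib->lock is held across both the cache lookup and the compilation above, so concurrent
// callers requesting the same specialized pipeline serialize and the function is compiled only once.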
| 448 | |||
| 449 | // | ||
| 450 | // MTLComputeCommandEncoder wrapper | ||
| 451 | // | ||
| 452 | |||
| 453 | struct ggml_metal_encoder { | ||
| 454 | id<MTLComputeCommandEncoder> obj; | ||
| 455 | }; | ||
| 456 | |||
| 457 | ggml_metal_encoder_t ggml_metal_encoder_init(ggml_metal_cmd_buf_t cmd_buf_raw, bool concurrent) { | ||
| 458 | ggml_metal_encoder_t res = calloc(1, sizeof(struct ggml_metal_encoder)); | ||
| 459 | |||
| 460 | id<MTLCommandBuffer> cmd_buf = (id<MTLCommandBuffer>) cmd_buf_raw; | ||
| 461 | |||
| 462 | if (concurrent) { | ||
| 463 | res->obj = [cmd_buf computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent]; | ||
| 464 | } else { | ||
| 465 | res->obj = [cmd_buf computeCommandEncoder]; | ||
| 466 | } | ||
| 467 | |||
| 468 | [res->obj retain]; | ||
| 469 | |||
| 470 | return res; | ||
| 471 | } | ||
| 472 | |||
| 473 | void ggml_metal_encoder_free(ggml_metal_encoder_t encoder) { | ||
| 474 | [encoder->obj release]; | ||
| 475 | free(encoder); | ||
| 476 | } | ||
| 477 | |||
| 478 | void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name) { | ||
| 479 | [encoder->obj pushDebugGroup:[NSString stringWithCString:name encoding:NSUTF8StringEncoding]]; | ||
| 480 | } | ||
| 481 | |||
| 482 | void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder) { | ||
| 483 | [encoder->obj popDebugGroup]; | ||
| 484 | } | ||
| 485 | |||
| 486 | void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline) { | ||
| 487 | [encoder->obj setComputePipelineState:pipeline.pipeline->obj]; | ||
| 488 | } | ||
| 489 | |||
| 490 | void ggml_metal_encoder_set_bytes(ggml_metal_encoder_t encoder, void * data, size_t size, int idx) { | ||
| 491 | [encoder->obj setBytes:data length:size atIndex:idx]; | ||
| 492 | } | ||
| 493 | |||
| 494 | void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx) { | ||
| 495 | [encoder->obj setBuffer:buffer.metal offset:buffer.offs atIndex:idx]; | ||
| 496 | } | ||
| 497 | |||
| 498 | void ggml_metal_encoder_set_threadgroup_memory_size(ggml_metal_encoder_t encoder, size_t size, int idx) { | ||
| 499 | [encoder->obj setThreadgroupMemoryLength:size atIndex:idx]; | ||
| 500 | } | ||
| 501 | |||
| 502 | void ggml_metal_encoder_dispatch_threadgroups(ggml_metal_encoder_t encoder, int tg0, int tg1, int tg2, int tptg0, int tptg1, int tptg2) { | ||
| 503 | [encoder->obj dispatchThreadgroups:MTLSizeMake(tg0, tg1, tg2) threadsPerThreadgroup:MTLSizeMake(tptg0, tptg1, tptg2)]; | ||
| 504 | } | ||
| 505 | |||
| 506 | void ggml_metal_encoder_memory_barrier(ggml_metal_encoder_t encoder) { | ||
| 507 | [encoder->obj memoryBarrierWithScope:MTLBarrierScopeBuffers]; | ||
| 508 | } | ||
| 509 | |||
| 510 | void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder) { | ||
| 511 | [encoder->obj endEncoding]; | ||
| 512 | } | ||
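// Usage sketch: a minimal example of driving the encoder wrapper above, assuming `cmd_buf`,
// `ppl`, `args`, `bid_src`, `bid_dst` and the dispatch sizes are provided by the surrounding
// op implementation:
//
//   ggml_metal_encoder_t enc = ggml_metal_encoder_init(cmd_buf, /*concurrent =*/ true);
//
//   ggml_metal_encoder_set_pipeline(enc, ppl);
//   ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
//   ggml_metal_encoder_set_buffer  (enc, bid_src, 1);
//   ggml_metal_encoder_set_buffer  (enc, bid_dst, 2);
//
//   ggml_metal_encoder_dispatch_threadgroups(enc, ntg0, ntg1, ntg2, nth0, nth1, nth2);
//
//   ggml_metal_encoder_end_encoding(enc);
//   ggml_metal_encoder_free(enc);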
| 513 | |||
| 514 | struct ggml_metal_device { | ||
| 515 | id<MTLDevice> mtl_device; | ||
| 516 | |||
| 517 | // a single global queue shared by all Metal backends | ||
| 518 | // technically not needed for devices with unified memory, but enables support for discrete GPUs | ||
| 519 | // ref: https://github.com/ggml-org/llama.cpp/pull/15906 | ||
| 520 | id<MTLCommandQueue> mtl_queue; | ||
| 521 | |||
| 522 | ggml_metal_rsets_t rsets; | ||
| 523 | |||
| 524 | ggml_metal_library_t library; | ||
| 525 | |||
| 526 | struct ggml_metal_device_props props; | ||
| 527 | |||
| 528 | // virtual address for GPU memory allocations | ||
| 529 | atomic_uintptr_t addr_virt; | ||
| 530 | }; | ||
| 531 | |||
| 532 | // | ||
| 533 | // MTLResidenceSet wrapper | ||
| 534 | // | ||
| 535 | |||
| 536 | struct ggml_metal_rsets { | ||
| 537 | NSLock * lock; | ||
| 538 | |||
| 539 | NSMutableArray * data; | ||
| 540 | |||
| 541 | // number of seconds to keep the residency sets wired after the last graph computation | ||
| 542 | // this prevents the OS from collecting them while the device is briefly idle | ||
| 543 | int keep_alive_s; | ||
| 544 | |||
| 545 | // background heartbeat thread to keep the residency sets alive | ||
| 546 | atomic_bool d_stop; | ||
| 547 | atomic_int d_loop; | ||
| 548 | |||
| 549 | dispatch_group_t d_group; | ||
| 550 | }; | ||
| 551 | |||
| 552 | ggml_metal_rsets_t ggml_metal_rsets_init(void) { | ||
| 553 | ggml_metal_rsets_t res = calloc(1, sizeof(struct ggml_metal_rsets)); | ||
| 554 | |||
| 555 | res->lock = [[NSLock alloc] init]; | ||
| 556 | res->data = [[NSMutableArray alloc] init]; | ||
| 557 | |||
| 558 | // by default keep the memory wired for 3 minutes | ||
| 559 | res->keep_alive_s = 3*60; | ||
| 560 | |||
| 561 | const char * GGML_METAL_RESIDENCY_KEEP_ALIVE_S = getenv("GGML_METAL_RESIDENCY_KEEP_ALIVE_S"); | ||
| 562 | if (GGML_METAL_RESIDENCY_KEEP_ALIVE_S) { | ||
| 563 | res->keep_alive_s = atoi(GGML_METAL_RESIDENCY_KEEP_ALIVE_S); | ||
| 564 | } | ||
| 565 | |||
| 566 | if (res->keep_alive_s <= 0) { | ||
| 567 | res->keep_alive_s = 3*60; | ||
| 568 | } | ||
| 569 | |||
| 570 | GGML_LOG_INFO("%s: creating a residency set collection (keep_alive = %d s)\n", __func__, res->keep_alive_s); | ||
| 571 | |||
| 572 | atomic_store_explicit(&res->d_stop, false, memory_order_relaxed); | ||
| 573 | atomic_store_explicit(&res->d_loop, 2*res->keep_alive_s, memory_order_relaxed); | ||
| 574 | |||
| 575 | res->d_group = dispatch_group_create(); | ||
| 576 | |||
| 577 | // start a background thread that periodically requests residency for all the currently active sets in the collection | ||
| 578 | // the requests stop after a certain amount of time (keep_alive_s) of inactivity | ||
| 579 | dispatch_queue_t d_queue = dispatch_get_global_queue(QOS_CLASS_DEFAULT, 0); | ||
| 580 | dispatch_group_async(res->d_group, d_queue, ^{ | ||
| 581 | #if defined(GGML_METAL_HAS_RESIDENCY_SETS) | ||
| 582 | if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { | ||
| 583 | while (!atomic_load_explicit(&res->d_stop, memory_order_relaxed)) { | ||
| 584 | if (atomic_load_explicit(&res->d_loop, memory_order_relaxed) > 0) { | ||
| 585 | [res->lock lock]; | ||
| 586 | |||
| 587 | for (int i = 0; i < (int) res->data.count; ++i) { | ||
| 588 | [res->data[i] requestResidency]; | ||
| 589 | } | ||
| 590 | |||
| 591 | atomic_fetch_sub_explicit(&res->d_loop, 1, memory_order_relaxed); | ||
| 592 | |||
| 593 | [res->lock unlock]; | ||
| 594 | } | ||
| 595 | |||
| 596 | // half a second | ||
| 597 | usleep(500 * 1000); | ||
| 598 | } | ||
| 599 | } | ||
| 600 | #endif | ||
| 601 | }); | ||
| 602 | |||
| 603 | return res; | ||
| 604 | } | ||
| 605 | |||
| 606 | void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) { | ||
| 607 | if (rsets == NULL) { | ||
| 608 | return; | ||
| 609 | } | ||
| 610 | |||
| 611 | // note: if you hit this assert, most likely you haven't deallocated all Metal resources before exiting | ||
| 612 | GGML_ASSERT([rsets->data count] == 0); | ||
| 613 | |||
| 614 | atomic_store_explicit(&rsets->d_stop, true, memory_order_relaxed); | ||
| 615 | |||
| 616 | dispatch_group_wait(rsets->d_group, DISPATCH_TIME_FOREVER); | ||
| 617 | dispatch_release(rsets->d_group); | ||
| 618 | |||
| 619 | [rsets->data release]; | ||
| 620 | [rsets->lock release]; | ||
| 621 | |||
| 622 | free(rsets); | ||
| 623 | } | ||
| 624 | |||
| 625 | ggml_metal_device_t ggml_metal_device_init(int device) { | ||
| 626 | ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device)); | ||
| 627 | |||
| 628 | assert(dev != NULL); | ||
| 629 | |||
| 630 | if (dev->mtl_device == nil) { | ||
| 631 | dev->mtl_device = MTLCreateSystemDefaultDevice(); | ||
| 632 | |||
| 633 | if (dev->mtl_device) { | ||
| 634 | dev->mtl_queue = [dev->mtl_device newCommandQueue]; | ||
| 635 | if (dev->mtl_queue == nil) { | ||
| 636 | GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); | ||
| 637 | } | ||
| 638 | |||
| 639 | dev->addr_virt = 0x000000400ULL; | ||
| 640 | |||
| 641 | dev->props.device = device; | ||
| 642 | dev->props.has_simdgroup_reduction = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; | ||
| 643 | dev->props.has_simdgroup_reduction |= [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; | ||
| 644 | |||
| 645 | dev->props.has_simdgroup_mm = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; | ||
| 646 | dev->props.has_unified_memory = dev->mtl_device.hasUnifiedMemory; | ||
| 647 | |||
| 648 | dev->props.has_bfloat = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; | ||
| 649 | dev->props.has_bfloat |= [dev->mtl_device supportsFamily:MTLGPUFamilyApple6]; | ||
| 650 | if (getenv("GGML_METAL_BF16_DISABLE") != NULL) { | ||
| 651 | dev->props.has_bfloat = false; | ||
| 652 | } | ||
| 653 | |||
| 654 | dev->props.has_tensor = [dev->mtl_device supportsFamily:MTLGPUFamilyMetal4_GGML]; | ||
| 655 | if (getenv("GGML_METAL_TENSOR_DISABLE") != NULL) { | ||
| 656 | dev->props.has_tensor = false; | ||
| 657 | } | ||
| 658 | |||
| 659 | // note: disable the tensor API by default for old chips because with the current implementation it is not useful | ||
| 660 | // - M2 Ultra: ~5% slower | ||
| 661 | // - M4, M4 Max: no significant difference | ||
| 662 | // | ||
| 663 | // TODO: try to update the tensor API kernels to at least match the simdgroup performance | ||
| 664 | if (getenv("GGML_METAL_TENSOR_ENABLE") == NULL && | ||
| 665 | ![[dev->mtl_device name] containsString:@"M5"] && | ||
| 666 | ![[dev->mtl_device name] containsString:@"M6"] && | ||
| 667 | ![[dev->mtl_device name] containsString:@"A19"] && | ||
| 668 | ![[dev->mtl_device name] containsString:@"A20"]) { | ||
| 669 | GGML_LOG_WARN("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__); | ||
| 670 | dev->props.has_tensor = false; | ||
| 671 | } | ||
| 672 | |||
| 673 | // double-check that the tensor API compiles | ||
| 674 | if (dev->props.has_tensor) { | ||
| 675 | const char * src_tensor_f16 = "\n" | ||
| 676 | "#include <metal_stdlib> \n" | ||
| 677 | "#include <metal_tensor> \n" | ||
| 678 | "#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n" | ||
| 679 | " \n" | ||
| 680 | "using namespace metal; \n" | ||
| 681 | "using namespace mpp::tensor_ops; \n" | ||
| 682 | " \n" | ||
| 683 | "kernel void dummy_kernel( \n" | ||
| 684 | " tensor<device half, dextents<int32_t, 2>> A [[buffer(0)]], \n" | ||
| 685 | " tensor<device half, dextents<int32_t, 2>> B [[buffer(1)]], \n" | ||
| 686 | " device float * C [[buffer(2)]], \n" | ||
| 687 | " uint2 tgid [[threadgroup_position_in_grid]]) \n" | ||
| 688 | "{ \n" | ||
| 689 | " auto tA = A.slice(0, (int)tgid.y); \n" | ||
| 690 | " auto tB = B.slice((int)tgid.x, 0); \n" | ||
| 691 | " \n" | ||
| 692 | " matmul2d< \n" | ||
| 693 | " matmul2d_descriptor(8, 8, dynamic_extent), \n" | ||
| 694 | " execution_simdgroups<4>> mm; \n" | ||
| 695 | " \n" | ||
| 696 | " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n" | ||
| 697 | " \n" | ||
| 698 | " auto sA = tA.slice(0, 0); \n" | ||
| 699 | " auto sB = tB.slice(0, 0); \n" | ||
| 700 | " mm.run(sB, sA, cT); \n" | ||
| 701 | " \n" | ||
| 702 | " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n" | ||
| 703 | " \n" | ||
| 704 | " cT.store(tC); \n" | ||
| 705 | "}"; | ||
| 706 | |||
| 707 | GGML_LOG_INFO("%s: testing tensor API for f16 support\n", __func__); | ||
| 708 | ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_f16, false); | ||
| 709 | if (lib == NULL) { | ||
| 710 | GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__); | ||
| 711 | dev->props.has_tensor = false; | ||
| 712 | } else { | ||
| 713 | struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); | ||
| 714 | if (!ppl.pipeline) { | ||
| 715 | GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__); | ||
| 716 | dev->props.has_tensor = false; | ||
| 717 | } | ||
| 718 | |||
| 719 | ggml_metal_library_free(lib); | ||
| 720 | } | ||
| 721 | } | ||
| 722 | |||
| 723 | // try to compile a dummy kernel to determine if the tensor API is supported for bfloat | ||
| 724 | if (dev->props.has_tensor && dev->props.has_bfloat) { | ||
| 725 | const char * src_tensor_bf16 = "\n" | ||
| 726 | "#include <metal_stdlib> \n" | ||
| 727 | "#include <metal_tensor> \n" | ||
| 728 | "#include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> \n" | ||
| 729 | " \n" | ||
| 730 | "using namespace metal; \n" | ||
| 731 | "using namespace mpp::tensor_ops; \n" | ||
| 732 | " \n" | ||
| 733 | "kernel void dummy_kernel( \n" | ||
| 734 | " tensor<device bfloat, dextents<int32_t, 2>> A [[buffer(0)]], \n" | ||
| 735 | " tensor<device bfloat, dextents<int32_t, 2>> B [[buffer(1)]], \n" | ||
| 736 | " device float * C [[buffer(2)]], \n" | ||
| 737 | " uint2 tgid [[threadgroup_position_in_grid]]) \n" | ||
| 738 | "{ \n" | ||
| 739 | " auto tA = A.slice(0, (int)tgid.y); \n" | ||
| 740 | " auto tB = B.slice((int)tgid.x, 0); \n" | ||
| 741 | " \n" | ||
| 742 | " matmul2d< \n" | ||
| 743 | " matmul2d_descriptor(8, 8, dynamic_extent), \n" | ||
| 744 | " execution_simdgroups<4>> mm; \n" | ||
| 745 | " \n" | ||
| 746 | " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n" | ||
| 747 | " \n" | ||
| 748 | " auto sA = tA.slice(0, 0); \n" | ||
| 749 | " auto sB = tB.slice(0, 0); \n" | ||
| 750 | " mm.run(sB, sA, cT); \n" | ||
| 751 | " \n" | ||
| 752 | " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n" | ||
| 753 | " \n" | ||
| 754 | " cT.store(tC); \n" | ||
| 755 | "}"; | ||
| 756 | |||
| 757 | GGML_LOG_INFO("%s: testing tensor API for bfloat support\n", __func__); | ||
| 758 | ggml_metal_library_t lib = ggml_metal_library_init_from_source(dev, src_tensor_bf16, false); | ||
| 759 | if (lib == NULL) { | ||
| 760 | GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__); | ||
| 761 | dev->props.has_bfloat = false; | ||
| 762 | } else { | ||
| 763 | struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil); | ||
| 764 | if (!ppl.pipeline) { | ||
| 765 | GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__); | ||
| 766 | dev->props.has_bfloat = false; | ||
| 767 | } | ||
| 768 | |||
| 769 | ggml_metal_library_free(lib); | ||
| 770 | } | ||
| 771 | } | ||
| 772 | |||
| 773 | dev->props.use_residency_sets = true; | ||
| 774 | #if defined(GGML_METAL_HAS_RESIDENCY_SETS) | ||
| 775 | dev->props.use_residency_sets = getenv("GGML_METAL_NO_RESIDENCY") == nil; | ||
| 776 | #endif | ||
| 777 | |||
| 778 | dev->props.use_shared_buffers = dev->props.has_unified_memory; | ||
| 779 | #if TARGET_OS_OSX | ||
| 780 | // In case of eGPU, shared memory may be preferable. | ||
| 781 | dev->props.use_shared_buffers |= [dev->mtl_device location] == MTLDeviceLocationExternal; | ||
| 782 | #endif | ||
| 783 | if (getenv("GGML_METAL_SHARED_BUFFERS_DISABLE") != NULL) { | ||
| 784 | dev->props.use_shared_buffers = false; | ||
| 785 | } | ||
| 786 | if (getenv("GGML_METAL_SHARED_BUFFERS_ENABLE") != NULL) { | ||
| 787 | dev->props.use_shared_buffers = true; | ||
| 788 | } | ||
| 789 | |||
| 790 | dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; | ||
| 791 | |||
| 792 | dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; | ||
| 793 | |||
| 794 | dev->props.max_buffer_size = dev->mtl_device.maxBufferLength; | ||
| 795 | dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength; | ||
| 796 | if (@available(macOS 10.12, iOS 16.0, *)) { | ||
| 797 | dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize; | ||
| 798 | } else { | ||
| 799 | dev->props.max_working_set_size = dev->mtl_device.maxBufferLength; | ||
| 800 | } | ||
| 801 | |||
| 802 | snprintf(dev->props.name, sizeof(dev->props.name), "%s%d", "MTL", device); | ||
| 803 | snprintf(dev->props.desc, sizeof(dev->props.desc), "%s", [[dev->mtl_device name] UTF8String]); | ||
| 804 | |||
| 805 | dev->library = ggml_metal_library_init(dev); | ||
| 806 | if (!dev->library) { | ||
| 807 | GGML_LOG_ERROR("%s: error: failed to create library\n", __func__); | ||
| 808 | } | ||
| 809 | |||
| 810 | if (dev->props.use_residency_sets) { | ||
| 811 | dev->rsets = ggml_metal_rsets_init(); | ||
| 812 | } else { | ||
| 813 | dev->rsets = nil; | ||
| 814 | } | ||
| 815 | |||
| 816 | // print MTL GPU family: | ||
| 817 | GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name); | ||
| 818 | |||
| 819 | // determine max supported GPU family | ||
| 820 | // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf | ||
| 821 | // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf | ||
| 822 | { | ||
| 823 | for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { | ||
| 824 | if ([dev->mtl_device supportsFamily:i]) { | ||
| 825 | GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i); | ||
| 826 | break; | ||
| 827 | } | ||
| 828 | } | ||
| 829 | |||
| 830 | for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) { | ||
| 831 | if ([dev->mtl_device supportsFamily:i]) { | ||
| 832 | GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i); | ||
| 833 | break; | ||
| 834 | } | ||
| 835 | } | ||
| 836 | |||
| 837 | for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) { | ||
| 838 | if ([dev->mtl_device supportsFamily:i]) { | ||
| 839 | GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i); | ||
| 840 | break; | ||
| 841 | } | ||
| 842 | } | ||
| 843 | } | ||
| 844 | |||
| 845 | GGML_LOG_INFO("%s: simdgroup reduction = %s\n", __func__, dev->props.has_simdgroup_reduction ? "true" : "false"); | ||
| 846 | GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, dev->props.has_simdgroup_mm ? "true" : "false"); | ||
| 847 | GGML_LOG_INFO("%s: has unified memory = %s\n", __func__, dev->props.has_unified_memory ? "true" : "false"); | ||
| 848 | GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, dev->props.has_bfloat ? "true" : "false"); | ||
| 849 | GGML_LOG_INFO("%s: has tensor = %s\n", __func__, dev->props.has_tensor ? "true" : "false"); | ||
| 850 | GGML_LOG_INFO("%s: use residency sets = %s\n", __func__, dev->props.use_residency_sets ? "true" : "false"); | ||
| 851 | GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, dev->props.use_shared_buffers ? "true" : "false"); | ||
| 852 | |||
| 853 | #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) | ||
| 854 | if (@available(macOS 10.12, iOS 16.0, *)) { | ||
| 855 | GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, dev->props.max_working_set_size / 1e6); | ||
| 856 | } | ||
| 857 | #endif | ||
| 858 | } | ||
| 859 | } | ||
| 860 | |||
| 861 | return dev; | ||
| 862 | } | ||
| 863 | |||
| 864 | void ggml_metal_device_free(ggml_metal_device_t dev) { | ||
| 865 | assert(dev != NULL); | ||
| 866 | |||
| 867 | ggml_metal_rsets_free(dev->rsets); | ||
| 868 | |||
| 869 | ggml_metal_library_free(dev->library); | ||
| 870 | dev->library = NULL; | ||
| 871 | |||
| 872 | if (dev->mtl_queue) { | ||
| 873 | [dev->mtl_queue release]; | ||
| 874 | dev->mtl_queue = nil; | ||
| 875 | } | ||
| 876 | |||
| 877 | if (dev->mtl_device) { | ||
| 878 | [dev->mtl_device release]; | ||
| 879 | dev->mtl_device = nil; | ||
| 880 | } | ||
| 881 | |||
| 882 | free(dev); | ||
| 883 | } | ||
| 884 | |||
| 885 | void * ggml_metal_device_get_obj(ggml_metal_device_t dev) { | ||
| 886 | return dev->mtl_device; | ||
| 887 | } | ||
| 888 | |||
| 889 | void * ggml_metal_device_get_queue(ggml_metal_device_t dev) { | ||
| 890 | return dev->mtl_queue; | ||
| 891 | } | ||
| 892 | |||
| 893 | ggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev) { | ||
| 894 | return dev->library; | ||
| 895 | } | ||
| 896 | |||
| 897 | void ggml_metal_device_rsets_add(ggml_metal_device_t dev, ggml_metal_rset_t rset) { | ||
| 898 | if (rset == nil) { | ||
| 899 | return; | ||
| 900 | } | ||
| 901 | |||
| 902 | GGML_ASSERT(dev->rsets); | ||
| 903 | |||
| 904 | [dev->rsets->lock lock]; | ||
| 905 | |||
| 906 | [dev->rsets->data addObject:rset]; | ||
| 907 | |||
| 908 | [dev->rsets->lock unlock]; | ||
| 909 | } | ||
| 910 | |||
| 911 | void ggml_metal_device_rsets_rm(ggml_metal_device_t dev, ggml_metal_rset_t rset) { | ||
| 912 | if (rset == nil) { | ||
| 913 | return; | ||
| 914 | } | ||
| 915 | |||
| 916 | GGML_ASSERT(dev->rsets); | ||
| 917 | |||
| 918 | [dev->rsets->lock lock]; | ||
| 919 | |||
| 920 | [dev->rsets->data removeObject:rset]; | ||
| 921 | |||
| 922 | [dev->rsets->lock unlock]; | ||
| 923 | } | ||
| 924 | |||
| 925 | void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) { | ||
| 926 | if (dev->rsets == NULL) { | ||
| 927 | return; | ||
| 928 | } | ||
| 929 | |||
| 930 | atomic_store_explicit(&dev->rsets->d_loop, 2*dev->rsets->keep_alive_s, memory_order_relaxed); | ||
| 931 | } | ||
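// Timing sketch: with the default keep_alive_s = 180, the store above resets d_loop to 360;
// the heartbeat started in ggml_metal_rsets_init() decrements d_loop once per 500 ms while
// re-requesting residency, i.e.
//
//   2*keep_alive_s ticks * 0.5 s/tick = keep_alive_s = 180 s
//
// of wired residency after the last graph computation, after which the requests pause until
// the next keep-alive call.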
| 932 | |||
| 933 | struct ggml_metal_event { | ||
| 934 | void * obj; // id<MTLEvent> | ||
| 935 | |||
| 936 | atomic_int value; | ||
| 937 | }; | ||
| 938 | |||
| 939 | void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) { | ||
| 940 | id<MTLEvent> event = (id<MTLEvent>)ev->obj; | ||
| 941 | |||
| 942 | id<MTLCommandBuffer> cmd_buf = (id<MTLCommandBuffer>) cmd_buf_raw; | ||
| 943 | |||
| 944 | [cmd_buf encodeSignalEvent:event value:atomic_fetch_add_explicit(&ev->value, 1, memory_order_relaxed) + 1]; | ||
| 945 | } | ||
| 946 | |||
| 947 | void ggml_metal_event_encode_wait(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) { | ||
| 948 | id<MTLEvent> event = (id<MTLEvent>)ev->obj; | ||
| 949 | |||
| 950 | id<MTLCommandBuffer> cmd_buf = (id<MTLCommandBuffer>) cmd_buf_raw; | ||
| 951 | |||
| 952 | [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)]; | ||
| 953 | } | ||
| 954 | |||
| 955 | ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev) { | ||
| 956 | id<MTLEvent> event = [dev->mtl_device newEvent]; | ||
| 957 | |||
| 958 | ggml_metal_event_t ev = calloc(1, sizeof(struct ggml_metal_event)); | ||
| 959 | |||
| 960 | ev->obj = (__bridge void *)event; | ||
| 961 | ev->value = 0; | ||
| 962 | |||
| 963 | return ev; | ||
| 964 | } | ||
| 965 | |||
| 966 | void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev) { | ||
| 967 | id<MTLEvent> event = ev->obj; | ||
| 968 | [event release]; | ||
| 969 | |||
| 970 | free(ev); | ||
| 971 | |||
| 972 | GGML_UNUSED(dev); | ||
| 973 | } | ||
| 974 | |||
| 975 | void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev) { | ||
| 976 | @autoreleasepool { | ||
| 977 | id<MTLEvent> event = ev->obj; | ||
| 978 | |||
| 979 | id<MTLCommandBuffer> cmd_buf = [dev->mtl_queue commandBuffer]; | ||
| 980 | [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)]; | ||
| 981 | [cmd_buf commit]; | ||
| 982 | [cmd_buf waitUntilCompleted]; | ||
| 983 | } | ||
| 984 | } | ||
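// Usage sketch: a minimal cross-command-buffer dependency, assuming cmd_buf_a and cmd_buf_b
// were created on the shared queue; the monotonically increasing counter pairs each signal
// with the corresponding wait:
//
//   ggml_metal_event_t ev = ggml_metal_device_event_init(dev);
//
//   ggml_metal_event_encode_signal(ev, cmd_buf_a); // cmd_buf_a will signal value 1
//   ggml_metal_event_encode_wait  (ev, cmd_buf_b); // cmd_buf_b waits until value 1 is signaled
//
//   // ... commit both command buffers ...
//
//   ggml_metal_device_event_synchronize(dev, ev);  // host-side wait for the latest value
//   ggml_metal_device_event_free(dev, ev);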
| 985 | |||
| 986 | void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) { | ||
| 987 | if (@available(macOS 10.12, iOS 16.0, *)) { | ||
| 988 | *total = dev->mtl_device.recommendedMaxWorkingSetSize; | ||
| 989 | *free = *total - dev->mtl_device.currentAllocatedSize; | ||
| 990 | } else { | ||
| 991 | *free = 0; | ||
| 992 | *total = 0; | ||
| 993 | } | ||
| 994 | } | ||
| 995 | |||
| 996 | bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op) { | ||
| 997 | const bool has_simdgroup_mm = dev->props.has_simdgroup_mm; | ||
| 998 | const bool has_simdgroup_reduction = dev->props.has_simdgroup_reduction; | ||
| 999 | const bool has_bfloat = dev->props.has_bfloat; | ||
| 1000 | |||
| 1001 | if (!has_bfloat) { | ||
| 1002 | if (op->type == GGML_TYPE_BF16) { | ||
| 1003 | return false; | ||
| 1004 | } | ||
| 1005 | |||
| 1006 | for (size_t i = 0, n = 3; i < n; ++i) { | ||
| 1007 | if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) { | ||
| 1008 | return false; | ||
| 1009 | } | ||
| 1010 | } | ||
| 1011 | } | ||
| 1012 | |||
| 1013 | switch (op->op) { | ||
| 1014 | case GGML_OP_SCALE: | ||
| 1015 | case GGML_OP_FILL: | ||
| 1016 | case GGML_OP_CLAMP: | ||
| 1017 | case GGML_OP_SQR: | ||
| 1018 | case GGML_OP_SQRT: | ||
| 1019 | case GGML_OP_SIN: | ||
| 1020 | case GGML_OP_COS: | ||
| 1021 | case GGML_OP_LOG: | ||
| 1022 | return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); | ||
| 1023 | case GGML_OP_UNARY: | ||
| 1024 | switch (ggml_get_unary_op(op)) { | ||
| 1025 | case GGML_UNARY_OP_TANH: | ||
| 1026 | case GGML_UNARY_OP_RELU: | ||
| 1027 | case GGML_UNARY_OP_SIGMOID: | ||
| 1028 | case GGML_UNARY_OP_GELU: | ||
| 1029 | case GGML_UNARY_OP_GELU_ERF: | ||
| 1030 | case GGML_UNARY_OP_GELU_QUICK: | ||
| 1031 | case GGML_UNARY_OP_SILU: | ||
| 1032 | case GGML_UNARY_OP_ELU: | ||
| 1033 | case GGML_UNARY_OP_NEG: | ||
| 1034 | case GGML_UNARY_OP_ABS: | ||
| 1035 | case GGML_UNARY_OP_SGN: | ||
| 1036 | case GGML_UNARY_OP_STEP: | ||
| 1037 | case GGML_UNARY_OP_HARDSWISH: | ||
| 1038 | case GGML_UNARY_OP_HARDSIGMOID: | ||
| 1039 | case GGML_UNARY_OP_EXP: | ||
| 1040 | case GGML_UNARY_OP_SOFTPLUS: | ||
| 1041 | case GGML_UNARY_OP_EXPM1: | ||
| 1042 | return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); | ||
| 1043 | default: | ||
| 1044 | return false; | ||
| 1045 | } | ||
| 1046 | case GGML_OP_GLU: | ||
| 1047 | switch (ggml_get_glu_op(op)) { | ||
| 1048 | case GGML_GLU_OP_REGLU: | ||
| 1049 | case GGML_GLU_OP_GEGLU: | ||
| 1050 | case GGML_GLU_OP_SWIGLU: | ||
| 1051 | case GGML_GLU_OP_SWIGLU_OAI: | ||
| 1052 | case GGML_GLU_OP_GEGLU_ERF: | ||
| 1053 | case GGML_GLU_OP_GEGLU_QUICK: | ||
| 1054 | return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; | ||
| 1055 | default: | ||
| 1056 | return false; | ||
| 1057 | } | ||
| 1058 | case GGML_OP_NONE: | ||
| 1059 | case GGML_OP_RESHAPE: | ||
| 1060 | case GGML_OP_VIEW: | ||
| 1061 | case GGML_OP_TRANSPOSE: | ||
| 1062 | case GGML_OP_PERMUTE: | ||
| 1063 | case GGML_OP_CONCAT: | ||
| 1064 | return true; | ||
| 1065 | case GGML_OP_ADD: | ||
| 1066 | case GGML_OP_SUB: | ||
| 1067 | case GGML_OP_MUL: | ||
| 1068 | case GGML_OP_DIV: | ||
| 1069 | case GGML_OP_ADD_ID: | ||
| 1070 | return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32; | ||
| 1071 | case GGML_OP_ACC: | ||
| 1072 | case GGML_OP_REPEAT: | ||
| 1073 | case GGML_OP_CONV_TRANSPOSE_1D: | ||
| 1074 | return true; | ||
| 1075 | case GGML_OP_CONV_TRANSPOSE_2D: | ||
| 1076 | return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && | ||
| 1077 | (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) && | ||
| 1078 | op->src[1]->type == GGML_TYPE_F32 && | ||
| 1079 | op->type == GGML_TYPE_F32; | ||
| 1080 | case GGML_OP_SUM: | ||
| 1081 | return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); | ||
| 1082 | case GGML_OP_TRI: | ||
| 1083 | return ggml_is_contiguous_rows(op->src[0]); | ||
| 1084 | case GGML_OP_SUM_ROWS: | ||
| 1085 | case GGML_OP_CUMSUM: | ||
| 1086 | case GGML_OP_MEAN: | ||
| 1087 | case GGML_OP_SOFT_MAX: | ||
| 1088 | case GGML_OP_GROUP_NORM: | ||
| 1089 | case GGML_OP_L2_NORM: | ||
| 1090 | return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); | ||
| 1091 | case GGML_OP_COUNT_EQUAL: | ||
| 1092 | return has_simdgroup_reduction && | ||
| 1093 | op->src[0]->type == GGML_TYPE_I32 && | ||
| 1094 | op->src[1]->type == GGML_TYPE_I32 && | ||
| 1095 | op->type == GGML_TYPE_I64; | ||
| 1096 | case GGML_OP_ARGMAX: | ||
| 1097 | return has_simdgroup_reduction; | ||
| 1098 | case GGML_OP_NORM: | ||
| 1099 | case GGML_OP_RMS_NORM: | ||
| 1100 | return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0])); | ||
| 1101 | case GGML_OP_ROPE: | ||
| 1102 | return true; | ||
| 1103 | case GGML_OP_IM2COL: | ||
| 1104 | return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32); | ||
| 1105 | case GGML_OP_CONV_2D: | ||
| 1106 | return ggml_is_contiguous(op->src[0]) && | ||
| 1107 | op->src[1]->type == GGML_TYPE_F32 && | ||
| 1108 | op->type == GGML_TYPE_F32 && | ||
| 1109 | (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); | ||
| 1110 | case GGML_OP_UPSCALE: | ||
| 1111 | return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); | ||
| 1112 | case GGML_OP_POOL_1D: | ||
| 1113 | return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; | ||
| 1114 | case GGML_OP_POOL_2D: | ||
| 1115 | return op->src[0]->type == GGML_TYPE_F32; | ||
| 1116 | case GGML_OP_PAD: | ||
| 1117 | // TODO: add circular padding support for metal, see https://github.com/ggml-org/llama.cpp/pull/16985 | ||
| 1118 | if (ggml_get_op_params_i32(op, 8) != 0) { | ||
| 1119 | return false; | ||
| 1120 | } | ||
| 1121 | |||
| 1122 | return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && | ||
| 1123 | (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); | ||
| 1124 | case GGML_OP_PAD_REFLECT_1D: | ||
| 1125 | case GGML_OP_TIMESTEP_EMBEDDING: | ||
| 1126 | case GGML_OP_LEAKY_RELU: | ||
| 1127 | return op->src[0]->type == GGML_TYPE_F32; | ||
| 1128 | case GGML_OP_ARGSORT: | ||
| 1129 | case GGML_OP_TOP_K: | ||
| 1130 | case GGML_OP_ARANGE: | ||
| 1131 | return true; | ||
| 1132 | case GGML_OP_FLASH_ATTN_EXT: | ||
| 1133 | // for new head sizes, add checks here | ||
| 1134 | if (op->src[0]->ne[0] != 32 && | ||
| 1135 | op->src[0]->ne[0] != 40 && | ||
| 1136 | op->src[0]->ne[0] != 48 && | ||
| 1137 | op->src[0]->ne[0] != 64 && | ||
| 1138 | op->src[0]->ne[0] != 72 && | ||
| 1139 | op->src[0]->ne[0] != 80 && | ||
| 1140 | op->src[0]->ne[0] != 96 && | ||
| 1141 | op->src[0]->ne[0] != 112 && | ||
| 1142 | op->src[0]->ne[0] != 128 && | ||
| 1143 | op->src[0]->ne[0] != 192 && | ||
| 1144 | op->src[0]->ne[0] != 256 && | ||
| 1145 | op->src[0]->ne[0] != 576) { | ||
| 1146 | return false; | ||
| 1147 | } | ||
| 1148 | if (op->src[1]->type != op->src[2]->type) { | ||
| 1149 | return false; | ||
| 1150 | } | ||
| 1151 | return has_simdgroup_mm; // TODO: over-restricted for vec-kernels | ||
| 1152 | case GGML_OP_SSM_CONV: | ||
| 1153 | case GGML_OP_SSM_SCAN: | ||
| 1154 | return has_simdgroup_reduction; | ||
| 1155 | case GGML_OP_RWKV_WKV6: | ||
| 1156 | case GGML_OP_RWKV_WKV7: | ||
| 1157 | return true; | ||
| 1158 | case GGML_OP_SOLVE_TRI: | ||
| 1159 | case GGML_OP_MUL_MAT: | ||
| 1160 | case GGML_OP_MUL_MAT_ID: | ||
| 1161 | return has_simdgroup_reduction; | ||
| 1162 | case GGML_OP_CPY: | ||
| 1163 | case GGML_OP_DUP: | ||
| 1164 | case GGML_OP_CONT: | ||
| 1165 | { | ||
| 1166 | switch (op->src[0]->type) { | ||
| 1167 | case GGML_TYPE_F32: | ||
| 1168 | switch (op->type) { | ||
| 1169 | case GGML_TYPE_F32: | ||
| 1170 | case GGML_TYPE_F16: | ||
| 1171 | case GGML_TYPE_BF16: | ||
| 1172 | case GGML_TYPE_Q8_0: | ||
| 1173 | case GGML_TYPE_Q4_0: | ||
| 1174 | case GGML_TYPE_Q4_1: | ||
| 1175 | case GGML_TYPE_Q5_0: | ||
| 1176 | case GGML_TYPE_Q5_1: | ||
| 1177 | case GGML_TYPE_IQ4_NL: | ||
| 1178 | case GGML_TYPE_I32: | ||
| 1179 | return true; | ||
| 1180 | default: | ||
| 1181 | return false; | ||
| 1182 | } | ||
| 1183 | case GGML_TYPE_F16: | ||
| 1184 | switch (op->type) { | ||
| 1185 | case GGML_TYPE_F32: | ||
| 1186 | case GGML_TYPE_F16: | ||
| 1187 | return true; | ||
| 1188 | default: | ||
| 1189 | return false; | ||
| 1190 | } | ||
| 1191 | case GGML_TYPE_BF16: | ||
| 1192 | switch (op->type) { | ||
| 1193 | case GGML_TYPE_F32: | ||
| 1194 | case GGML_TYPE_BF16: | ||
| 1195 | return true; | ||
| 1196 | default: | ||
| 1197 | return false; | ||
| 1198 | } | ||
| 1199 | case GGML_TYPE_Q4_0: | ||
| 1200 | case GGML_TYPE_Q4_1: | ||
| 1201 | case GGML_TYPE_Q5_0: | ||
| 1202 | case GGML_TYPE_Q5_1: | ||
| 1203 | case GGML_TYPE_Q8_0: | ||
| 1204 | switch (op->type) { | ||
| 1205 | case GGML_TYPE_F32: | ||
| 1206 | case GGML_TYPE_F16: | ||
| 1207 | return true; | ||
| 1208 | default: | ||
| 1209 | return false; | ||
| 1210 | } | ||
| 1211 | case GGML_TYPE_I32: | ||
| 1212 | return op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32; | ||
| 1213 | default: | ||
| 1214 | return false; | ||
| 1215 | }; | ||
| 1216 | } | ||
| 1217 | case GGML_OP_GET_ROWS: | ||
| 1218 | return true; | ||
| 1219 | case GGML_OP_SET_ROWS: | ||
| 1220 | { | ||
| 1221 | if (op->src[0]->type != GGML_TYPE_F32) { | ||
| 1222 | return false; | ||
| 1223 | } | ||
| 1224 | |||
| 1225 | switch (op->type) { | ||
| 1226 | case GGML_TYPE_F32: | ||
| 1227 | case GGML_TYPE_F16: | ||
| 1228 | case GGML_TYPE_BF16: | ||
| 1229 | case GGML_TYPE_Q8_0: | ||
| 1230 | case GGML_TYPE_Q4_0: | ||
| 1231 | case GGML_TYPE_Q4_1: | ||
| 1232 | case GGML_TYPE_Q5_0: | ||
| 1233 | case GGML_TYPE_Q5_1: | ||
| 1234 | case GGML_TYPE_IQ4_NL: | ||
| 1235 | return true; | ||
| 1236 | default: | ||
| 1237 | return false; | ||
| 1238 | }; | ||
| 1239 | } | ||
| 1240 | case GGML_OP_DIAG: | ||
| 1241 | return true; | ||
| 1242 | case GGML_OP_OPT_STEP_ADAMW: | ||
| 1243 | case GGML_OP_OPT_STEP_SGD: | ||
| 1244 | return has_simdgroup_reduction; | ||
| 1245 | default: | ||
| 1246 | return false; | ||
| 1247 | } | ||
| 1248 | } | ||
| 1249 | |||
| 1250 | const struct ggml_metal_device_props * ggml_metal_device_get_props(ggml_metal_device_t dev) { | ||
| 1251 | return &dev->props; | ||
| 1252 | } | ||
| 1253 | |||
| 1254 | // | ||
| 1255 | // device buffers | ||
| 1256 | // | ||
| 1257 | |||
| 1258 | // max memory buffers that can be mapped to the device | ||
| 1259 | #define GGML_METAL_MAX_BUFFERS 64 | ||
| 1260 | |||
| 1261 | struct ggml_metal_buffer_wrapper { | ||
| 1262 | void * data; | ||
| 1263 | size_t size; | ||
| 1264 | |||
| 1265 | id<MTLBuffer> metal; | ||
| 1266 | }; | ||
| 1267 | |||
| 1268 | struct ggml_metal_buffer { | ||
| 1269 | void * all_data; | ||
| 1270 | size_t all_size; | ||
| 1271 | |||
| 1272 | // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host | ||
| 1273 | bool is_shared; | ||
| 1274 | bool owned; | ||
| 1275 | |||
| 1276 | // multiple buffers are used only to avoid the maximum buffer size limitation when using mmap | ||
| 1277 | int n_buffers; | ||
| 1278 | struct ggml_metal_buffer_wrapper buffers[GGML_METAL_MAX_BUFFERS]; | ||
| 1279 | |||
| 1280 | bool use_residency_sets; | ||
| 1281 | |||
| 1282 | // optional MTLResidencySet | ||
| 1283 | // note: cannot use "id<MTLResidencySet>" explicitly here because it is not available on certain OSes | ||
| 1284 | id rset; | ||
| 1285 | |||
| 1286 | // pointers to global device | ||
| 1287 | ggml_metal_device_t dev; | ||
| 1288 | }; | ||
| 1289 | |||
| 1290 | static void ggml_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) { | ||
| 1291 | #ifndef GGML_METAL_NDEBUG | ||
| 1292 | #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) | ||
| 1293 | if (@available(macOS 10.12, iOS 16.0, *)) { | ||
| 1294 | GGML_LOG_DEBUG("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)\n", | ||
| 1295 | __func__, | ||
| 1296 | size_aligned / 1024.0 / 1024.0, | ||
| 1297 | device.currentAllocatedSize / 1024.0 / 1024.0, | ||
| 1298 | device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); | ||
| 1299 | |||
| 1300 | if (device.currentAllocatedSize > device.recommendedMaxWorkingSetSize) { | ||
| 1301 | GGML_LOG_WARN("%s: warning: current allocated size is greater than the recommended max working set size\n", __func__); | ||
| 1302 | } | ||
| 1303 | } else { | ||
| 1304 | GGML_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", | ||
| 1305 | __func__, | ||
| 1306 | size_aligned / 1024.0 / 1024.0, | ||
| 1307 | device.currentAllocatedSize / 1024.0 / 1024.0); | ||
| 1308 | } | ||
| 1309 | #endif | ||
| 1310 | #endif | ||
| 1311 | GGML_UNUSED(device); | ||
| 1312 | GGML_UNUSED(size_aligned); | ||
| 1313 | } | ||
| 1314 | |||
| 1315 | // rset init | ||
| 1316 | static bool ggml_metal_buffer_rset_init(ggml_metal_buffer_t buf) { | ||
| 1317 | buf->rset = nil; | ||
| 1318 | |||
| 1319 | if (!buf->use_residency_sets) { | ||
| 1320 | return true; | ||
| 1321 | } | ||
| 1322 | |||
| 1323 | #if defined(GGML_METAL_HAS_RESIDENCY_SETS) | ||
| 1324 | if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { | ||
| 1325 | MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init]; | ||
| 1326 | desc.label = @"ggml_metal"; | ||
| 1327 | desc.initialCapacity = buf->n_buffers; | ||
| 1328 | |||
| 1329 | NSError * error; | ||
| 1330 | buf->rset = [buf->dev->mtl_device newResidencySetWithDescriptor:desc error:&error]; | ||
| 1331 | if (error) { | ||
| 1332 | GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); | ||
| 1333 | [desc release]; | ||
| 1334 | return false; | ||
| 1335 | } | ||
| 1336 | |||
| 1337 | [desc release]; | ||
| 1338 | |||
| 1339 | for (int i = 0; i < buf->n_buffers; i++) { | ||
| 1340 | [buf->rset addAllocation:buf->buffers[i].metal]; | ||
| 1341 | } | ||
| 1342 | |||
| 1343 | [buf->rset commit]; | ||
| 1344 | [buf->rset requestResidency]; | ||
| 1345 | |||
| 1346 | return true; | ||
| 1347 | } | ||
| 1348 | #endif | ||
| 1349 | |||
| 1350 | return true; | ||
| 1351 | } | ||
| 1352 | |||
| 1353 | // rset free | ||
| 1354 | static void ggml_metal_buffer_rset_free(ggml_metal_buffer_t buf) { | ||
| 1355 | #if defined(GGML_METAL_HAS_RESIDENCY_SETS) | ||
| 1356 | if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) { | ||
| 1357 | if (buf->rset) { | ||
| 1358 | [buf->rset endResidency]; | ||
| 1359 | [buf->rset removeAllAllocations]; | ||
| 1360 | [buf->rset release]; | ||
| 1361 | } | ||
| 1362 | } | ||
| 1363 | #else | ||
| 1364 | GGML_UNUSED(buf); | ||
| 1365 | #endif | ||
| 1366 | } | ||
| 1367 | |||
| 1368 | static void * ggml_metal_host_malloc(size_t n) { | ||
| 1369 | void * data = NULL; | ||
| 1370 | |||
| 1371 | #if TARGET_OS_OSX | ||
| 1372 | kern_return_t err = vm_allocate((vm_map_t) mach_task_self(), (void *) &data, n, VM_FLAGS_ANYWHERE); | ||
| 1373 | if (err != KERN_SUCCESS) { | ||
| 1374 | GGML_LOG_ERROR("%s: error: vm_allocate failed\n", __func__); | ||
| 1375 | return NULL; | ||
| 1376 | } | ||
| 1377 | #else | ||
| 1378 | const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); | ||
| 1379 | if (result != 0) { | ||
| 1380 | GGML_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); | ||
| 1381 | return NULL; | ||
| 1382 | } | ||
| 1383 | #endif | ||
| 1384 | |||
| 1385 | return data; | ||
| 1386 | } | ||
| 1387 | |||
| 1388 | ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared) { | ||
| 1389 | ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer)); | ||
| 1390 | |||
| 1391 | res->dev = dev; | ||
| 1392 | |||
| 1393 | const size_t size_page = sysconf(_SC_PAGESIZE); | ||
| 1394 | |||
| 1395 | size_t size_aligned = size; | ||
| 1396 | if ((size_aligned % size_page) != 0) { | ||
| 1397 | size_aligned += (size_page - (size_aligned % size_page)); | ||
| 1398 | } | ||
| 1399 | |||
| 1400 | const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); | ||
| 1401 | |||
| 1402 | shared = shared && props_dev->use_shared_buffers; | ||
| 1403 | |||
| 1404 | // allocate shared buffer if the device supports it and it is required by the buffer type | ||
| 1405 | if (shared) { | ||
| 1406 | res->all_data = ggml_metal_host_malloc(size_aligned); | ||
| 1407 | res->is_shared = true; | ||
| 1408 | } else { | ||
| 1409 | // use virtual address | ||
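// (the returned address is not host-accessible memory; it starts at the small non-zero base
//  set in ggml_metal_device_init() and only serves to map tensor data pointers back to a
//  MTLBuffer + offset in ggml_metal_buffer_get_id() below)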
| 1410 | res->all_data = (void *) atomic_fetch_add_explicit(&dev->addr_virt, size_aligned, memory_order_relaxed); | ||
| 1411 | res->is_shared = false; | ||
| 1412 | } | ||
| 1413 | res->all_size = size_aligned; | ||
| 1414 | |||
| 1415 | res->owned = true; | ||
| 1416 | |||
| 1417 | res->n_buffers = 1; | ||
| 1418 | |||
| 1419 | if (res->all_data != NULL) { | ||
| 1420 | res->buffers[0].size = size; | ||
| 1421 | res->buffers[0].metal = nil; | ||
| 1422 | |||
| 1423 | if (size_aligned > 0) { | ||
| 1424 | if (props_dev->use_shared_buffers && shared) { | ||
| 1425 | res->buffers[0].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:res->all_data | ||
| 1426 | length:size_aligned | ||
| 1427 | options:MTLResourceStorageModeShared | ||
| 1428 | deallocator:nil]; | ||
| 1429 | } else { | ||
| 1430 | res->buffers[0].metal = [res->dev->mtl_device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate]; | ||
| 1431 | } | ||
| 1432 | } | ||
| 1433 | |||
| 1434 | res->buffers[0].data = res->all_data; | ||
| 1435 | } | ||
| 1436 | |||
| 1437 | if (size_aligned > 0 && (res->all_data == NULL || res->buffers[0].metal == nil)) { | ||
| 1438 | GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); | ||
| 1439 | free(res); | ||
| 1440 | return NULL; | ||
| 1441 | } | ||
| 1442 | |||
| 1443 | res->use_residency_sets = props_dev->use_residency_sets; | ||
| 1444 | |||
| 1445 | if (!ggml_metal_buffer_rset_init(res)) { | ||
| 1446 | GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); | ||
| 1447 | free(res); | ||
| 1448 | return NULL; | ||
| 1449 | } | ||
| 1450 | |||
| 1451 | ggml_metal_device_rsets_add(dev, res->rset); | ||
| 1452 | |||
| 1453 | //ggml_metal_log_allocated_size(device, size_aligned); | ||
| 1454 | |||
| 1455 | return res; | ||
| 1456 | } | ||
| 1457 | |||
| 1458 | ggml_metal_buffer_t ggml_metal_buffer_map(ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size) { | ||
| 1459 | ggml_metal_buffer_t res = calloc(1, sizeof(struct ggml_metal_buffer)); | ||
| 1460 | |||
| 1461 | res->dev = dev; | ||
| 1462 | |||
| 1463 | res->all_data = ptr; | ||
| 1464 | res->all_size = size; | ||
| 1465 | |||
| 1466 | res->is_shared = true; | ||
| 1467 | res->owned = false; | ||
| 1468 | |||
| 1469 | res->n_buffers = 0; | ||
| 1470 | |||
| 1471 | const size_t size_page = sysconf(_SC_PAGESIZE); | ||
| 1472 | |||
| 1473 | // page-align the data ptr | ||
| 1474 | { | ||
| 1475 | const uintptr_t offs = (uintptr_t) ptr % size_page; | ||
| 1476 | ptr = (void *) ((char *) ptr - offs); | ||
| 1477 | size += offs; | ||
| 1478 | } | ||
| 1479 | |||
| 1480 | size_t size_aligned = size; | ||
| 1481 | if ((size_aligned % size_page) != 0) { | ||
| 1482 | size_aligned += (size_page - (size_aligned % size_page)); | ||
| 1483 | } | ||
| 1484 | |||
| 1485 | const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); | ||
| 1486 | |||
| 1487 | // the buffer fits into the max buffer size allowed by the device | ||
| 1488 | if (size_aligned <= props_dev->max_buffer_size) { | ||
| 1489 | res->buffers[res->n_buffers].data = ptr; | ||
| 1490 | res->buffers[res->n_buffers].size = size; | ||
| 1491 | res->buffers[res->n_buffers].metal = nil; | ||
| 1492 | |||
| 1493 | if (size_aligned > 0) { | ||
| 1494 | res->buffers[res->n_buffers].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; | ||
| 1495 | |||
| 1496 | if (res->buffers[res->n_buffers].metal == nil) { | ||
| 1497 | GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0); | ||
| 1498 | free(res); | ||
| 1499 | return NULL; | ||
| 1500 | } | ||
| 1501 | } | ||
| 1502 | |||
| 1503 | ggml_metal_log_allocated_size(res->dev->mtl_device, size_aligned); | ||
| 1504 | |||
| 1505 | ++res->n_buffers; | ||
| 1506 | } else { | ||
| 1507 | // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into | ||
| 1508 | // one of the views | ||
| 1509 | const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case | ||
| 1510 | const size_t size_step = props_dev->max_buffer_size - size_ovlp; | ||
| 1511 | const size_t size_view = props_dev->max_buffer_size; | ||
| 1512 | |||
| 1513 | for (size_t i = 0; i < size; i += size_step) { | ||
| 1514 | const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i); | ||
| 1515 | |||
| 1516 | res->buffers[res->n_buffers].data = (void *) ((uint8_t *) ptr + i); | ||
| 1517 | res->buffers[res->n_buffers].size = size_step_aligned; | ||
| 1518 | res->buffers[res->n_buffers].metal = nil; | ||
| 1519 | |||
| 1520 | if (size_step_aligned > 0) { | ||
| 1521 | res->buffers[res->n_buffers].metal = [res->dev->mtl_device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; | ||
| 1522 | |||
| 1523 | if (res->buffers[res->n_buffers].metal == nil) { | ||
| 1524 | GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0); | ||
| 1525 | free(res); | ||
| 1526 | return NULL; | ||
| 1527 | } | ||
| 1528 | } | ||
| 1529 | |||
| 1530 | ggml_metal_log_allocated_size(res->dev->mtl_device, size_step_aligned); | ||
| 1531 | |||
| 1532 | if (i + size_step < size) { | ||
| 1533 | GGML_LOG_INFO("\n"); | ||
| 1534 | } | ||
| 1535 | |||
| 1536 | ++res->n_buffers; | ||
| 1537 | } | ||
| 1538 | } | ||
| 1539 | |||
| 1540 | res->use_residency_sets = props_dev->use_residency_sets; | ||
| 1541 | |||
| 1542 | if (!ggml_metal_buffer_rset_init(res)) { | ||
| 1543 | GGML_LOG_ERROR("%s: error: failed to initialize residency set\n", __func__); | ||
| 1544 | free(res); | ||
| 1545 | return NULL; | ||
| 1546 | } | ||
| 1547 | |||
| 1548 | ggml_metal_device_rsets_add(dev, res->rset); | ||
| 1549 | |||
| 1550 | return res; | ||
| 1551 | } | ||
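// Sizing sketch with illustrative numbers: for max_buffer_size = 16 GiB, max_tensor_size = 1 GiB
// and a 48 GiB mapping, size_ovlp rounds up to roughly 1 GiB, size_step is roughly 15 GiB and
// size_view is 16 GiB. Consecutive views therefore start 15 GiB apart but span 16 GiB, and the
// roughly 1 GiB overlap guarantees that any tensor of at most max_tensor_size bytes lies fully
// inside at least one view, which is what ggml_metal_buffer_get_id() relies on.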
| 1552 | |||
| 1553 | void ggml_metal_buffer_free(ggml_metal_buffer_t buf) { | ||
| 1554 | ggml_metal_device_rsets_rm(buf->dev, buf->rset); | ||
| 1555 | |||
| 1556 | for (int i = 0; i < buf->n_buffers; i++) { | ||
| 1557 | [buf->buffers[i].metal release]; | ||
| 1558 | } | ||
| 1559 | |||
| 1560 | ggml_metal_buffer_rset_free(buf); | ||
| 1561 | |||
| 1562 | if (buf->is_shared && buf->owned) { | ||
| 1563 | #if TARGET_OS_OSX | ||
| 1564 | vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)buf->all_data, buf->all_size); | ||
| 1565 | #else | ||
| 1566 | free(buf->all_data); | ||
| 1567 | #endif | ||
| 1568 | } | ||
| 1569 | |||
| 1570 | free(buf); | ||
| 1571 | } | ||
| 1572 | |||
| 1573 | void * ggml_metal_buffer_get_base(ggml_metal_buffer_t buf) { | ||
| 1574 | return buf->all_data; | ||
| 1575 | } | ||
| 1576 | |||
| 1577 | bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) { | ||
| 1578 | return buf->is_shared; | ||
| 1579 | } | ||
| 1580 | |||
| 1581 | void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { | ||
| 1582 | if (buf->is_shared) { | ||
| 1583 | memset((char *) tensor->data + offset, value, size); | ||
| 1584 | return; | ||
| 1585 | } | ||
| 1586 | |||
| 1587 | @autoreleasepool { | ||
| 1588 | // dst | ||
| 1589 | struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf, tensor); | ||
| 1590 | bid_dst.offs += offset; | ||
| 1591 | |||
| 1592 | id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences]; | ||
| 1593 | |||
| 1594 | { | ||
| 1595 | id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder]; | ||
| 1596 | |||
| 1597 | [encoder fillBuffer:bid_dst.metal | ||
| 1598 | range:NSMakeRange(bid_dst.offs, size) | ||
| 1599 | value:value]; | ||
| 1600 | |||
| 1601 | [encoder endEncoding]; | ||
| 1602 | } | ||
| 1603 | |||
| 1604 | [cmd_buf commit]; | ||
| 1605 | [cmd_buf waitUntilCompleted]; | ||
| 1606 | } | ||
| 1607 | } | ||
| 1608 | |||
| 1609 | void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { | ||
| 1610 | if (buf->is_shared) { | ||
| 1611 | memcpy((char *) tensor->data + offset, data, size); | ||
| 1612 | return; | ||
| 1613 | } | ||
| 1614 | |||
| 1615 | @autoreleasepool { | ||
| 1616 | // src | ||
| 1617 | void * data_ptr = (void *)(uintptr_t) data; // "const cast" the src data | ||
| 1618 | id<MTLBuffer> buf_src = [buf->dev->mtl_device newBufferWithBytesNoCopy:data_ptr | ||
| 1619 | length:size | ||
| 1620 | options:MTLResourceStorageModeShared | ||
| 1621 | deallocator:nil]; | ||
| 1622 | |||
| 1623 | GGML_ASSERT(buf_src); | ||
| 1624 | |||
| 1625 | // dst | ||
| 1626 | struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf, tensor); | ||
| 1627 | bid_dst.offs += offset; | ||
| 1628 | |||
| 1629 | // note: for experimentation purposes, here we use a semaphore to wait for the copy to complete | ||
| 1630 | // this is an alternative to waitUntilCompleted; it should be faster, but in practice it doesn't seem to make much difference | ||
| 1631 | dispatch_semaphore_t completion_semaphore = dispatch_semaphore_create(0); | ||
| 1632 | |||
| 1633 | id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences]; | ||
| 1634 | |||
| 1635 | { | ||
| 1636 | id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder]; | ||
| 1637 | |||
| 1638 | [encoder copyFromBuffer:buf_src | ||
| 1639 | sourceOffset:0 | ||
| 1640 | toBuffer:bid_dst.metal | ||
| 1641 | destinationOffset:bid_dst.offs | ||
| 1642 | size:size]; | ||
| 1643 | |||
| 1644 | [encoder endEncoding]; | ||
| 1645 | } | ||
| 1646 | |||
| 1647 | [cmd_buf addCompletedHandler:^(id<MTLCommandBuffer> cb) { | ||
| 1648 | // TODO: can check for errors here | ||
| 1649 | GGML_UNUSED(cb); | ||
| 1650 | |||
| 1651 | dispatch_semaphore_signal(completion_semaphore); | ||
| 1652 | }]; | ||
| 1653 | |||
| 1654 | [cmd_buf commit]; | ||
| 1655 | |||
| 1656 | dispatch_semaphore_wait(completion_semaphore, DISPATCH_TIME_FOREVER); | ||
| 1657 | dispatch_release(completion_semaphore); | ||
| 1658 | |||
| 1659 | //[cmd_buf waitUntilCompleted]; | ||
| 1660 | } | ||
| 1661 | } | ||
| 1662 | |||
| 1663 | void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { | ||
| 1664 | if (buf->is_shared) { | ||
| 1665 | memcpy(data, (const char *) tensor->data + offset, size); | ||
| 1666 | return; | ||
| 1667 | } | ||
| 1668 | |||
| 1669 | @autoreleasepool { | ||
| 1670 | // src | ||
| 1671 | struct ggml_metal_buffer_id bid_src = ggml_metal_buffer_get_id(buf, tensor); | ||
| 1672 | bid_src.offs += offset; | ||
| 1673 | |||
| 1674 | // dst | ||
| 1675 | id<MTLBuffer> buf_dst = [buf->dev->mtl_device newBufferWithBytesNoCopy:data | ||
| 1676 | length:size | ||
| 1677 | options:MTLResourceStorageModeShared | ||
| 1678 | deallocator:nil]; | ||
| 1679 | |||
| 1680 | GGML_ASSERT(buf_dst); | ||
| 1681 | |||
| 1682 | id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences]; | ||
| 1683 | |||
| 1684 | { | ||
| 1685 | id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder]; | ||
| 1686 | |||
| 1687 | [encoder copyFromBuffer:bid_src.metal | ||
| 1688 | sourceOffset:bid_src.offs | ||
| 1689 | toBuffer:buf_dst | ||
| 1690 | destinationOffset:0 | ||
| 1691 | size:size]; | ||
| 1692 | |||
| 1693 | [encoder endEncoding]; | ||
| 1694 | } | ||
| 1695 | |||
| 1696 | [cmd_buf commit]; | ||
| 1697 | [cmd_buf waitUntilCompleted]; | ||
| 1698 | } | ||
| 1699 | } | ||
| 1700 | |||
| 1701 | void ggml_metal_buffer_clear(ggml_metal_buffer_t buf, uint8_t value) { | ||
| 1702 | if (buf->is_shared) { | ||
| 1703 | memset(buf->all_data, value, buf->all_size); | ||
| 1704 | return; | ||
| 1705 | } | ||
| 1706 | |||
| 1707 | @autoreleasepool { | ||
| 1708 | id<MTLCommandBuffer> cmd_buf = [buf->dev->mtl_queue commandBufferWithUnretainedReferences]; | ||
| 1709 | |||
| 1710 | { | ||
| 1711 | id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder]; | ||
| 1712 | |||
| 1713 | [encoder fillBuffer:buf->buffers[0].metal | ||
| 1714 | range:NSMakeRange(0, buf->buffers[0].size) | ||
| 1715 | value:value]; | ||
| 1716 | |||
| 1717 | [encoder endEncoding]; | ||
| 1718 | } | ||
| 1719 | |||
| 1720 | [cmd_buf commit]; | ||
| 1721 | [cmd_buf waitUntilCompleted]; | ||
| 1722 | } | ||
| 1723 | } | ||
| 1724 | |||
| 1725 | struct ggml_metal_buffer_id ggml_metal_buffer_get_id(ggml_metal_buffer_t buf, const struct ggml_tensor * t) { | ||
| 1726 | struct ggml_metal_buffer_id res = { nil, 0 }; | ||
| 1727 | |||
| 1728 | const int64_t tsize = ggml_nbytes(t); | ||
| 1729 | |||
| 1730 | // find the view that contains the tensor fully | ||
| 1731 | for (int i = 0; i < buf->n_buffers; ++i) { | ||
| 1732 | const int64_t ioffs = (int64_t) t->data - (int64_t) buf->buffers[i].data; | ||
| 1733 | |||
| 1734 | //GGML_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, buf->buffers[%d].size = %10ld\n", ioffs, tsize, ioffs + tsize, i, buf->buffers[i].size); | ||
| 1735 | if (ioffs >= 0 && ioffs + tsize <= (int64_t) buf->buffers[i].size) { | ||
| 1736 | res.metal = buf->buffers[i].metal; | ||
| 1737 | res.offs = (size_t) ioffs; | ||
| 1738 | |||
| 1739 | //GGML_LOG_INFO("%s: tensor '%16s', offs = %8ld\n", __func__, t->name, *offs); | ||
| 1740 | |||
| 1741 | return res; | ||
| 1742 | } | ||
| 1743 | } | ||
| 1744 | |||
| 1745 | GGML_LOG_ERROR("%s: error: tensor '%s' buffer is nil\n", __func__, t->name); | ||
| 1746 | |||
| 1747 | return res; | ||
| 1748 | } | ||
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h b/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h new file mode 100644 index 0000000..952e1be --- /dev/null +++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h | |||
| @@ -0,0 +1,1051 @@ | |||
| 1 | #ifndef GGML_METAL_IMPL | ||
| 2 | #define GGML_METAL_IMPL | ||
| 3 | |||
| 4 | // kernel parameters for mat-vec threadgroups | ||
| 5 | // | ||
| 6 | // N_R0: number of src0 rows to process per simdgroup | ||
| 7 | // N_SG: number of simdgroups per threadgroup | ||
| 8 | // | ||
| 9 | // TODO: for optimal performance, these should become a function of the device and the work size | ||
| 10 | |||
| 11 | #define N_R0_Q4_0 4 | ||
| 12 | #define N_SG_Q4_0 2 | ||
| 13 | |||
| 14 | #define N_R0_Q4_1 4 | ||
| 15 | #define N_SG_Q4_1 2 | ||
| 16 | |||
| 17 | #define N_R0_Q5_0 4 | ||
| 18 | #define N_SG_Q5_0 2 | ||
| 19 | |||
| 20 | #define N_R0_Q5_1 4 | ||
| 21 | #define N_SG_Q5_1 2 | ||
| 22 | |||
| 23 | #define N_R0_Q8_0 2 | ||
| 24 | #define N_SG_Q8_0 4 | ||
| 25 | |||
| 26 | #define N_R0_MXFP4 2 | ||
| 27 | #define N_SG_MXFP4 2 | ||
| 28 | |||
| 29 | #define N_R0_Q2_K 4 | ||
| 30 | #define N_SG_Q2_K 2 | ||
| 31 | |||
| 32 | #define N_R0_Q3_K 2 | ||
| 33 | #define N_SG_Q3_K 2 | ||
| 34 | |||
| 35 | #define N_R0_Q4_K 2 | ||
| 36 | #define N_SG_Q4_K 2 | ||
| 37 | |||
| 38 | #define N_R0_Q5_K 2 | ||
| 39 | #define N_SG_Q5_K 2 | ||
| 40 | |||
| 41 | #define N_R0_Q6_K 2 | ||
| 42 | #define N_SG_Q6_K 2 | ||
| 43 | |||
| 44 | #define N_R0_IQ1_S 4 | ||
| 45 | #define N_SG_IQ1_S 2 | ||
| 46 | |||
| 47 | #define N_R0_IQ1_M 4 | ||
| 48 | #define N_SG_IQ1_M 2 | ||
| 49 | |||
| 50 | #define N_R0_IQ2_XXS 4 | ||
| 51 | #define N_SG_IQ2_XXS 2 | ||
| 52 | |||
| 53 | #define N_R0_IQ2_XS 4 | ||
| 54 | #define N_SG_IQ2_XS 2 | ||
| 55 | |||
| 56 | #define N_R0_IQ2_S 4 | ||
| 57 | #define N_SG_IQ2_S 2 | ||
| 58 | |||
| 59 | #define N_R0_IQ3_XXS 4 | ||
| 60 | #define N_SG_IQ3_XXS 2 | ||
| 61 | |||
| 62 | #define N_R0_IQ3_S 4 | ||
| 63 | #define N_SG_IQ3_S 2 | ||
| 64 | |||
| 65 | #define N_R0_IQ4_NL 2 | ||
| 66 | #define N_SG_IQ4_NL 2 | ||
| 67 | |||
| 68 | #define N_R0_IQ4_XS 2 | ||
| 69 | #define N_SG_IQ4_XS 2 | ||
| 70 | |||
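// Example: for Q4_0, N_R0_Q4_0 = 4 rows per simdgroup and N_SG_Q4_0 = 2 simdgroups per
// threadgroup, so each threadgroup covers N_SG * N_R0 = 8 rows of src0.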
| 71 | // function constants offsets | ||
| 72 | #define FC_FLASH_ATTN_EXT_PAD 100 | ||
| 73 | #define FC_FLASH_ATTN_EXT_BLK 200 | ||
| 74 | #define FC_FLASH_ATTN_EXT 300 | ||
| 75 | #define FC_FLASH_ATTN_EXT_VEC 400 | ||
| 76 | #define FC_FLASH_ATTN_EXT_VEC_REDUCE 500 | ||
| 77 | #define FC_MUL_MV 600 | ||
| 78 | #define FC_MUL_MM 700 | ||
| 79 | #define FC_ROPE 800 | ||
| 80 | #define FC_SSM_CONV 900 | ||
| 81 | #define FC_SOLVE_TRI 1000 | ||
| 82 | #define FC_COUNT_EQUAL 1100 | ||
| 83 | #define FC_UNARY 1200 | ||
| 84 | #define FC_BIN 1300 | ||
| 85 | |||
| 86 | // op-specific constants | ||
| 87 | #define OP_FLASH_ATTN_EXT_NQPSG 8 | ||
| 88 | #define OP_FLASH_ATTN_EXT_NCPSG 64 | ||
| 89 | |||
| 90 | #define OP_FLASH_ATTN_EXT_VEC_NQPSG 1 | ||
| 91 | #define OP_FLASH_ATTN_EXT_VEC_NCPSG 32 | ||
| 92 | |||
| 93 | #define OP_UNARY_NUM_SCALE 10 | ||
| 94 | #define OP_UNARY_NUM_FILL 11 | ||
| 95 | #define OP_UNARY_NUM_CLAMP 12 | ||
| 96 | #define OP_UNARY_NUM_SQR 13 | ||
| 97 | #define OP_UNARY_NUM_SQRT 14 | ||
| 98 | #define OP_UNARY_NUM_SIN 15 | ||
| 99 | #define OP_UNARY_NUM_COS 16 | ||
| 100 | #define OP_UNARY_NUM_LOG 17 | ||
| 101 | #define OP_UNARY_NUM_LEAKY_RELU 18 | ||
| 102 | |||
| 103 | #define OP_UNARY_NUM_TANH 100 | ||
| 104 | #define OP_UNARY_NUM_RELU 101 | ||
| 105 | #define OP_UNARY_NUM_SIGMOID 102 | ||
| 106 | #define OP_UNARY_NUM_GELU 103 | ||
| 107 | #define OP_UNARY_NUM_GELU_ERF 104 | ||
| 108 | #define OP_UNARY_NUM_GELU_QUICK 105 | ||
| 109 | #define OP_UNARY_NUM_SILU 106 | ||
| 110 | #define OP_UNARY_NUM_ELU 107 | ||
| 111 | #define OP_UNARY_NUM_NEG 108 | ||
| 112 | #define OP_UNARY_NUM_ABS 109 | ||
| 113 | #define OP_UNARY_NUM_SGN 110 | ||
| 114 | #define OP_UNARY_NUM_STEP 111 | ||
| 115 | #define OP_UNARY_NUM_HARDSWISH 112 | ||
| 116 | #define OP_UNARY_NUM_HARDSIGMOID 113 | ||
| 117 | #define OP_UNARY_NUM_EXP 114 | ||
| 118 | #define OP_UNARY_NUM_SOFTPLUS 115 | ||
| 119 | #define OP_UNARY_NUM_EXPM1 116 | ||
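The two numbering ranges appear to mirror the two op families that share the unary kernels: the values below 100 line up with GGML_OP-level element-wise ops (SCALE, FILL, CLAMP, SQR, SQRT, SIN, COS, LOG, LEAKY_RELU), while the values from 100 upward line up with GGML_UNARY_OP variants; both families are routed through ggml_metal_op_unary in ggml-metal-ops.cpp later in this commit.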
| 120 | |||
| 121 | |||
| 122 | // kernel argument structs | ||
| 123 | // | ||
| 124 | // - element counters (e.g. ne00) typically use int32_t to reduce register usage | ||
| 125 | // however, be careful of int overflows when using these in the kernel implementation | ||
| 126 | // | ||
| 127 | // - strides (e.g. nb00) use uint64_t | ||
| 128 | |||
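A minimal sketch of the convention above, using a hypothetical helper (not part of this commit) that copies a tensor's dimensions into a kargs struct:

    // hypothetical helper - only illustrates the int32_t-counts / uint64_t-strides convention
    static void fill_kargs_dims(const struct ggml_tensor * t, int32_t ne[4], uint64_t nb[4]) {
        for (int i = 0; i < 4; ++i) {
            ne[i] = (int32_t)  t->ne[i]; // counts are narrowed to 32 bits - beware of overflow on very large tensors
            nb[i] = (uint64_t) t->nb[i]; // strides stay 64-bit
        }
    }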
| 129 | typedef struct { | ||
| 130 | int32_t ne00; | ||
| 131 | int32_t ne01; | ||
| 132 | int32_t ne02; | ||
| 133 | int32_t ne03; | ||
| 134 | uint64_t nb00; | ||
| 135 | uint64_t nb01; | ||
| 136 | uint64_t nb02; | ||
| 137 | uint64_t nb03; | ||
| 138 | int32_t ne10; | ||
| 139 | int32_t ne11; | ||
| 140 | int32_t ne12; | ||
| 141 | int32_t ne13; | ||
| 142 | uint64_t nb10; | ||
| 143 | uint64_t nb11; | ||
| 144 | uint64_t nb12; | ||
| 145 | uint64_t nb13; | ||
| 146 | int32_t ne0; | ||
| 147 | int32_t ne1; | ||
| 148 | int32_t ne2; | ||
| 149 | int32_t ne3; | ||
| 150 | uint64_t nb0; | ||
| 151 | uint64_t nb1; | ||
| 152 | uint64_t nb2; | ||
| 153 | uint64_t nb3; | ||
| 154 | int32_t dim; | ||
| 155 | } ggml_metal_kargs_concat; | ||
| 156 | |||
| 157 | typedef struct { | ||
| 158 | int32_t ne00; | ||
| 159 | int32_t ne01; | ||
| 160 | int32_t ne02; | ||
| 161 | int32_t ne03; | ||
| 162 | uint64_t nb00; | ||
| 163 | uint64_t nb01; | ||
| 164 | uint64_t nb02; | ||
| 165 | uint64_t nb03; | ||
| 166 | int32_t ne0; | ||
| 167 | int32_t ne1; | ||
| 168 | int32_t ne2; | ||
| 169 | int32_t ne3; | ||
| 170 | uint64_t nb0; | ||
| 171 | uint64_t nb1; | ||
| 172 | uint64_t nb2; | ||
| 173 | uint64_t nb3; | ||
| 174 | float slope; | ||
| 175 | float scale; | ||
| 176 | float bias; | ||
| 177 | float val; | ||
| 178 | float min; | ||
| 179 | float max; | ||
| 180 | } ggml_metal_kargs_unary; | ||
| 181 | |||
| 182 | typedef struct { | ||
| 183 | int32_t ne00; | ||
| 184 | int32_t ne01; | ||
| 185 | int32_t ne02; | ||
| 186 | int32_t ne03; | ||
| 187 | uint64_t nb00; | ||
| 188 | uint64_t nb01; | ||
| 189 | uint64_t nb02; | ||
| 190 | uint64_t nb03; | ||
| 191 | int32_t ne10; | ||
| 192 | int32_t ne11; | ||
| 193 | int32_t ne12; | ||
| 194 | int32_t ne13; | ||
| 195 | uint64_t nb10; | ||
| 196 | uint64_t nb11; | ||
| 197 | uint64_t nb12; | ||
| 198 | uint64_t nb13; | ||
| 199 | int32_t ne0; | ||
| 200 | int32_t ne1; | ||
| 201 | int32_t ne2; | ||
| 202 | int32_t ne3; | ||
| 203 | uint64_t nb0; | ||
| 204 | uint64_t nb1; | ||
| 205 | uint64_t nb2; | ||
| 206 | uint64_t nb3; | ||
| 207 | uint64_t offs; | ||
| 208 | uint64_t o1[8]; | ||
| 209 | } ggml_metal_kargs_bin; | ||
| 210 | |||
| 211 | typedef struct { | ||
| 212 | int64_t ne0; | ||
| 213 | int64_t ne1; | ||
| 214 | size_t nb01; | ||
| 215 | size_t nb02; | ||
| 216 | size_t nb11; | ||
| 217 | size_t nb21; | ||
| 218 | } ggml_metal_kargs_add_id; | ||
| 219 | |||
| 220 | typedef struct { | ||
| 221 | int32_t ne00; | ||
| 222 | int32_t ne01; | ||
| 223 | int32_t ne02; | ||
| 224 | int32_t ne03; | ||
| 225 | uint64_t nb00; | ||
| 226 | uint64_t nb01; | ||
| 227 | uint64_t nb02; | ||
| 228 | uint64_t nb03; | ||
| 229 | int32_t ne0; | ||
| 230 | int32_t ne1; | ||
| 231 | int32_t ne2; | ||
| 232 | int32_t ne3; | ||
| 233 | uint64_t nb0; | ||
| 234 | uint64_t nb1; | ||
| 235 | uint64_t nb2; | ||
| 236 | uint64_t nb3; | ||
| 237 | } ggml_metal_kargs_repeat; | ||
| 238 | |||
| 239 | typedef struct { | ||
| 240 | int64_t nk0; | ||
| 241 | int64_t ne00; | ||
| 242 | int64_t ne01; | ||
| 243 | int64_t ne02; | ||
| 244 | int64_t ne03; | ||
| 245 | uint64_t nb00; | ||
| 246 | uint64_t nb01; | ||
| 247 | uint64_t nb02; | ||
| 248 | uint64_t nb03; | ||
| 249 | int64_t ne0; | ||
| 250 | int64_t ne1; | ||
| 251 | int64_t ne2; | ||
| 252 | int64_t ne3; | ||
| 253 | uint64_t nb0; | ||
| 254 | uint64_t nb1; | ||
| 255 | uint64_t nb2; | ||
| 256 | uint64_t nb3; | ||
| 257 | } ggml_metal_kargs_cpy; | ||
| 258 | |||
| 259 | typedef struct { | ||
| 260 | int64_t ne10; | ||
| 261 | int64_t ne11; | ||
| 262 | int64_t ne12; | ||
| 263 | uint64_t nb10; | ||
| 264 | uint64_t nb11; | ||
| 265 | uint64_t nb12; | ||
| 266 | uint64_t nb13; | ||
| 267 | uint64_t nb1; | ||
| 268 | uint64_t nb2; | ||
| 269 | uint64_t nb3; | ||
| 270 | uint64_t offs; | ||
| 271 | bool inplace; | ||
| 272 | } ggml_metal_kargs_set; | ||
| 273 | |||
| 274 | typedef struct { | ||
| 275 | int32_t ne00; | ||
| 276 | int32_t ne01; | ||
| 277 | int32_t ne02; | ||
| 278 | int32_t ne03; | ||
| 279 | uint64_t nb00; | ||
| 280 | uint64_t nb01; | ||
| 281 | uint64_t nb02; | ||
| 282 | uint64_t nb03; | ||
| 283 | int32_t ne0; | ||
| 284 | int32_t ne1; | ||
| 285 | int32_t ne2; | ||
| 286 | int32_t ne3; | ||
| 287 | uint64_t nb0; | ||
| 288 | uint64_t nb1; | ||
| 289 | uint64_t nb2; | ||
| 290 | uint64_t nb3; | ||
| 291 | int32_t n_past; | ||
| 292 | int32_t n_dims; | ||
| 293 | int32_t n_ctx_orig; | ||
| 294 | float freq_base; | ||
| 295 | float freq_scale; | ||
| 296 | float ext_factor; | ||
| 297 | float attn_factor; | ||
| 298 | float beta_fast; | ||
| 299 | float beta_slow; | ||
| 300 | int32_t sect_0; | ||
| 301 | int32_t sect_1; | ||
| 302 | int32_t sect_2; | ||
| 303 | int32_t sect_3; | ||
| 304 | bool src2; | ||
| 305 | } ggml_metal_kargs_rope; | ||
| 306 | |||
| 307 | typedef struct { | ||
| 308 | int32_t ne11; | ||
| 309 | int32_t ne_12_2; // assume K and V have the same shape | ||
| 310 | int32_t ne_12_3; | ||
| 311 | uint64_t nb11; | ||
| 312 | uint64_t nb12; | ||
| 313 | uint64_t nb13; | ||
| 314 | uint64_t nb21; | ||
| 315 | uint64_t nb22; | ||
| 316 | uint64_t nb23; | ||
| 317 | int32_t ne31; | ||
| 318 | int32_t ne32; | ||
| 319 | int32_t ne33; | ||
| 320 | uint64_t nb31; | ||
| 321 | uint64_t nb32; | ||
| 322 | uint64_t nb33; | ||
| 323 | } ggml_metal_kargs_flash_attn_ext_pad; | ||
| 324 | |||
| 325 | typedef struct { | ||
| 326 | int32_t ne01; | ||
| 327 | int32_t ne30; | ||
| 328 | int32_t ne31; | ||
| 329 | int32_t ne32; | ||
| 330 | int32_t ne33; | ||
| 331 | uint64_t nb31; | ||
| 332 | uint64_t nb32; | ||
| 333 | uint64_t nb33; | ||
| 334 | } ggml_metal_kargs_flash_attn_ext_blk; | ||
| 335 | |||
| 336 | typedef struct { | ||
| 337 | int32_t ne01; | ||
| 338 | int32_t ne02; | ||
| 339 | int32_t ne03; | ||
| 340 | uint64_t nb01; | ||
| 341 | uint64_t nb02; | ||
| 342 | uint64_t nb03; | ||
| 343 | int32_t ne11; | ||
| 344 | int32_t ne_12_2; // assume K and V have the same shape | ||
| 345 | int32_t ne_12_3; | ||
| 346 | int32_t ns10; | ||
| 347 | uint64_t nb11; | ||
| 348 | uint64_t nb12; | ||
| 349 | uint64_t nb13; | ||
| 350 | int32_t ns20; | ||
| 351 | uint64_t nb21; | ||
| 352 | uint64_t nb22; | ||
| 353 | uint64_t nb23; | ||
| 354 | int32_t ne31; | ||
| 355 | int32_t ne32; | ||
| 356 | int32_t ne33; | ||
| 357 | uint64_t nb31; | ||
| 358 | uint64_t nb32; | ||
| 359 | uint64_t nb33; | ||
| 360 | int32_t ne1; | ||
| 361 | int32_t ne2; | ||
| 362 | int32_t ne3; | ||
| 363 | float scale; | ||
| 364 | float max_bias; | ||
| 365 | float m0; | ||
| 366 | float m1; | ||
| 367 | int32_t n_head_log2; | ||
| 368 | float logit_softcap; | ||
| 369 | } ggml_metal_kargs_flash_attn_ext; | ||
| 370 | |||
| 371 | typedef struct { | ||
| 372 | int32_t ne01; | ||
| 373 | int32_t ne02; | ||
| 374 | int32_t ne03; | ||
| 375 | uint64_t nb01; | ||
| 376 | uint64_t nb02; | ||
| 377 | uint64_t nb03; | ||
| 378 | int32_t ne11; | ||
| 379 | int32_t ne_12_2; // assume K and V have the same shape | ||
| 380 | int32_t ne_12_3; | ||
| 381 | int32_t ns10; | ||
| 382 | uint64_t nb11; | ||
| 383 | uint64_t nb12; | ||
| 384 | uint64_t nb13; | ||
| 385 | int32_t ns20; | ||
| 386 | uint64_t nb21; | ||
| 387 | uint64_t nb22; | ||
| 388 | uint64_t nb23; | ||
| 389 | int32_t ne31; | ||
| 390 | int32_t ne32; | ||
| 391 | int32_t ne33; | ||
| 392 | uint64_t nb31; | ||
| 393 | uint64_t nb32; | ||
| 394 | uint64_t nb33; | ||
| 395 | int32_t ne1; | ||
| 396 | int32_t ne2; | ||
| 397 | int32_t ne3; | ||
| 398 | float scale; | ||
| 399 | float max_bias; | ||
| 400 | float m0; | ||
| 401 | float m1; | ||
| 402 | int32_t n_head_log2; | ||
| 403 | float logit_softcap; | ||
| 404 | } ggml_metal_kargs_flash_attn_ext_vec; | ||
| 405 | |||
| 406 | typedef struct { | ||
| 407 | int32_t nrows; | ||
| 408 | } ggml_metal_kargs_flash_attn_ext_vec_reduce; | ||
| 409 | |||
| 410 | typedef struct { | ||
| 411 | int32_t ne00; | ||
| 412 | int32_t ne02; | ||
| 413 | uint64_t nb01; | ||
| 414 | uint64_t nb02; | ||
| 415 | uint64_t nb03; | ||
| 416 | int32_t ne12; | ||
| 417 | uint64_t nb10; | ||
| 418 | uint64_t nb11; | ||
| 419 | uint64_t nb12; | ||
| 420 | uint64_t nb13; | ||
| 421 | int32_t ne0; | ||
| 422 | int32_t ne1; | ||
| 423 | int16_t r2; | ||
| 424 | int16_t r3; | ||
| 425 | } ggml_metal_kargs_mul_mm; | ||
| 426 | |||
| 427 | typedef struct { | ||
| 428 | int32_t ne00; | ||
| 429 | int32_t ne01; | ||
| 430 | int32_t ne02; | ||
| 431 | uint64_t nb00; | ||
| 432 | uint64_t nb01; | ||
| 433 | uint64_t nb02; | ||
| 434 | uint64_t nb03; | ||
| 435 | int32_t ne10; | ||
| 436 | int32_t ne11; | ||
| 437 | int32_t ne12; | ||
| 438 | uint64_t nb10; | ||
| 439 | uint64_t nb11; | ||
| 440 | uint64_t nb12; | ||
| 441 | uint64_t nb13; | ||
| 442 | int32_t ne0; | ||
| 443 | int32_t ne1; | ||
| 444 | int32_t nr0; | ||
| 445 | int16_t r2; | ||
| 446 | int16_t r3; | ||
| 447 | } ggml_metal_kargs_mul_mv; | ||
| 448 | |||
| 449 | typedef struct { | ||
| 450 | int32_t ne00; | ||
| 451 | int32_t ne01; | ||
| 452 | int32_t ne02; | ||
| 453 | uint64_t nb00; | ||
| 454 | uint64_t nb01; | ||
| 455 | uint64_t nb02; | ||
| 456 | uint64_t nb03; | ||
| 457 | int32_t ne10; | ||
| 458 | int32_t ne11; | ||
| 459 | int32_t ne12; | ||
| 460 | uint64_t nb10; | ||
| 461 | uint64_t nb11; | ||
| 462 | uint64_t nb12; | ||
| 463 | uint64_t nb13; | ||
| 464 | int32_t ne0; | ||
| 465 | int32_t ne1; | ||
| 466 | int16_t r2; | ||
| 467 | int16_t r3; | ||
| 468 | } ggml_metal_kargs_mul_mv_ext; | ||
| 469 | |||
| 470 | typedef struct { | ||
| 471 | int32_t ne02; | ||
| 472 | int32_t ne10; | ||
| 473 | int32_t ne11; // n_expert_used (bcast) | ||
| 474 | uint64_t nb11; | ||
| 475 | uint64_t nb12; | ||
| 476 | int32_t ne21; // n_tokens | ||
| 477 | int32_t ne20; // n_expert_used | ||
| 478 | uint64_t nb21; | ||
| 479 | } ggml_metal_kargs_mul_mm_id_map0; | ||
| 480 | |||
| 481 | typedef struct { | ||
| 482 | int32_t ne00; | ||
| 483 | int32_t ne02; | ||
| 484 | uint64_t nb01; | ||
| 485 | uint64_t nb02; | ||
| 486 | uint64_t nb03; | ||
| 487 | int32_t ne11; | ||
| 488 | uint64_t nb10; | ||
| 489 | uint64_t nb11; | ||
| 490 | uint64_t nb12; | ||
| 491 | uint64_t nb13; | ||
| 492 | int32_t ne20; | ||
| 493 | int32_t ne21; | ||
| 494 | int32_t ne0; | ||
| 495 | int32_t ne1; | ||
| 496 | int16_t r2; | ||
| 497 | int16_t r3; | ||
| 498 | } ggml_metal_kargs_mul_mm_id; | ||
| 499 | |||
| 500 | typedef struct { | ||
| 501 | int32_t nei0; | ||
| 502 | int32_t nei1; | ||
| 503 | uint64_t nbi1; | ||
| 504 | int32_t ne00; | ||
| 505 | int32_t ne01; | ||
| 506 | int32_t ne02; | ||
| 507 | uint64_t nb00; | ||
| 508 | uint64_t nb01; | ||
| 509 | uint64_t nb02; | ||
| 510 | int32_t ne10; | ||
| 511 | int32_t ne11; | ||
| 512 | int32_t ne12; | ||
| 513 | int32_t ne13; | ||
| 514 | uint64_t nb10; | ||
| 515 | uint64_t nb11; | ||
| 516 | uint64_t nb12; | ||
| 517 | int32_t ne0; | ||
| 518 | int32_t ne1; | ||
| 519 | uint64_t nb1; | ||
| 520 | int32_t nr0; | ||
| 521 | } ggml_metal_kargs_mul_mv_id; | ||
| 522 | |||
| 523 | // NORM | ||
| 524 | // RMS_NORM | ||
| 525 | typedef struct { | ||
| 526 | int32_t ne00; | ||
| 527 | int32_t ne00_t; | ||
| 528 | uint64_t nb1; | ||
| 529 | uint64_t nb2; | ||
| 530 | uint64_t nb3; | ||
| 531 | float eps; | ||
| 532 | int32_t nef1[3]; | ||
| 533 | int32_t nef2[3]; | ||
| 534 | int32_t nef3[3]; | ||
| 535 | uint64_t nbf1[3]; | ||
| 536 | uint64_t nbf2[3]; | ||
| 537 | uint64_t nbf3[3]; | ||
| 538 | } ggml_metal_kargs_norm; | ||
| 539 | |||
| 540 | typedef struct { | ||
| 541 | int32_t ne00; | ||
| 542 | int32_t ne01; | ||
| 543 | int32_t ne02; | ||
| 544 | int32_t ne03; | ||
| 545 | uint64_t nb00; | ||
| 546 | uint64_t nb01; | ||
| 547 | uint64_t nb02; | ||
| 548 | uint64_t nb03; | ||
| 549 | int32_t ne0; | ||
| 550 | int32_t ne1; | ||
| 551 | int32_t ne2; | ||
| 552 | int32_t ne3; | ||
| 553 | uint64_t nb0; | ||
| 554 | uint64_t nb1; | ||
| 555 | uint64_t nb2; | ||
| 556 | uint64_t nb3; | ||
| 557 | float eps; | ||
| 558 | } ggml_metal_kargs_l2_norm; | ||
| 559 | |||
| 560 | typedef struct { | ||
| 561 | int64_t ne00; | ||
| 562 | int64_t ne01; | ||
| 563 | int64_t ne02; | ||
| 564 | uint64_t nb00; | ||
| 565 | uint64_t nb01; | ||
| 566 | uint64_t nb02; | ||
| 567 | int32_t ngrp; | ||
| 568 | float eps; | ||
| 569 | } ggml_metal_kargs_group_norm; | ||
| 570 | |||
| 571 | typedef struct { | ||
| 572 | int32_t IC; | ||
| 573 | int32_t IL; | ||
| 574 | int32_t K; | ||
| 575 | int32_t s0; | ||
| 576 | uint64_t nb0; | ||
| 577 | uint64_t nb1; | ||
| 578 | } ggml_metal_kargs_conv_transpose_1d; | ||
| 579 | |||
| 580 | typedef struct { | ||
| 581 | int32_t IC; | ||
| 582 | int32_t IH; | ||
| 583 | int32_t IW; | ||
| 584 | int32_t KH; | ||
| 585 | int32_t KW; | ||
| 586 | int32_t OC; | ||
| 587 | int32_t s0; | ||
| 588 | uint64_t nb0; | ||
| 589 | uint64_t nb1; | ||
| 590 | uint64_t nb2; | ||
| 591 | } ggml_metal_kargs_conv_transpose_2d; | ||
| 592 | |||
| 593 | typedef struct { | ||
| 594 | uint64_t nb00; | ||
| 595 | uint64_t nb01; | ||
| 596 | uint64_t nb02; | ||
| 597 | uint64_t nb03; | ||
| 598 | uint64_t nb10; | ||
| 599 | uint64_t nb11; | ||
| 600 | uint64_t nb12; | ||
| 601 | uint64_t nb13; | ||
| 602 | uint64_t nb0; | ||
| 603 | uint64_t nb1; | ||
| 604 | uint64_t nb2; | ||
| 605 | uint64_t nb3; | ||
| 606 | int32_t IW; | ||
| 607 | int32_t IH; | ||
| 608 | int32_t KW; | ||
| 609 | int32_t KH; | ||
| 610 | int32_t IC; | ||
| 611 | int32_t OC; | ||
| 612 | int32_t OW; | ||
| 613 | int32_t OH; | ||
| 614 | int32_t N; | ||
| 615 | int32_t s0; | ||
| 616 | int32_t s1; | ||
| 617 | int32_t p0; | ||
| 618 | int32_t p1; | ||
| 619 | int32_t d0; | ||
| 620 | int32_t d1; | ||
| 621 | } ggml_metal_kargs_conv_2d; | ||
| 622 | |||
| 623 | typedef struct { | ||
| 624 | uint64_t ofs0; | ||
| 625 | uint64_t ofs1; | ||
| 626 | int32_t IW; | ||
| 627 | int32_t IH; | ||
| 628 | int32_t CHW; | ||
| 629 | int32_t s0; | ||
| 630 | int32_t s1; | ||
| 631 | int32_t p0; | ||
| 632 | int32_t p1; | ||
| 633 | int32_t d0; | ||
| 634 | int32_t d1; | ||
| 635 | int32_t N; | ||
| 636 | int32_t KH; | ||
| 637 | int32_t KW; | ||
| 638 | int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources | ||
| 639 | } ggml_metal_kargs_im2col; | ||
| 640 | |||
| 641 | typedef struct { | ||
| 642 | int32_t ne00; | ||
| 643 | uint64_t nb01; | ||
| 644 | int32_t ne10; | ||
| 645 | uint64_t nb11; | ||
| 646 | int32_t ne0; | ||
| 647 | uint64_t nb1; | ||
| 648 | int32_t i00; | ||
| 649 | int32_t i10; | ||
| 650 | float alpha; | ||
| 651 | float limit; | ||
| 652 | } ggml_metal_kargs_glu; | ||
| 653 | |||
| 654 | typedef struct { | ||
| 655 | uint64_t np; | ||
| 656 | } ggml_metal_kargs_sum; | ||
| 657 | |||
| 658 | typedef struct { | ||
| 659 | int64_t ne00; | ||
| 660 | int64_t ne01; | ||
| 661 | int64_t ne02; | ||
| 662 | int64_t ne03; | ||
| 663 | uint64_t nb00; | ||
| 664 | uint64_t nb01; | ||
| 665 | uint64_t nb02; | ||
| 666 | uint64_t nb03; | ||
| 667 | int64_t ne0; | ||
| 668 | int64_t ne1; | ||
| 669 | int64_t ne2; | ||
| 670 | int64_t ne3; | ||
| 671 | uint64_t nb0; | ||
| 672 | uint64_t nb1; | ||
| 673 | uint64_t nb2; | ||
| 674 | uint64_t nb3; | ||
| 675 | } ggml_metal_kargs_sum_rows; | ||
| 676 | |||
| 677 | typedef struct { | ||
| 678 | int64_t ne00; | ||
| 679 | int64_t ne01; | ||
| 680 | int64_t ne02; | ||
| 681 | int64_t ne03; | ||
| 682 | uint64_t nb00; | ||
| 683 | uint64_t nb01; | ||
| 684 | uint64_t nb02; | ||
| 685 | uint64_t nb03; | ||
| 686 | int64_t net0; | ||
| 687 | int64_t net1; | ||
| 688 | int64_t net2; | ||
| 689 | int64_t net3; | ||
| 690 | uint64_t nbt0; | ||
| 691 | uint64_t nbt1; | ||
| 692 | uint64_t nbt2; | ||
| 693 | uint64_t nbt3; | ||
| 694 | bool outb; | ||
| 695 | } ggml_metal_kargs_cumsum_blk; | ||
| 696 | |||
| 697 | typedef struct { | ||
| 698 | int64_t ne00; | ||
| 699 | int64_t ne01; | ||
| 700 | int64_t ne02; | ||
| 701 | int64_t ne03; | ||
| 702 | uint64_t nb00; | ||
| 703 | uint64_t nb01; | ||
| 704 | uint64_t nb02; | ||
| 705 | uint64_t nb03; | ||
| 706 | int64_t net0; | ||
| 707 | int64_t net1; | ||
| 708 | int64_t net2; | ||
| 709 | int64_t net3; | ||
| 710 | uint64_t nbt0; | ||
| 711 | uint64_t nbt1; | ||
| 712 | uint64_t nbt2; | ||
| 713 | uint64_t nbt3; | ||
| 714 | } ggml_metal_kargs_cumsum_add; | ||
| 715 | |||
| 716 | typedef struct { | ||
| 717 | int32_t ne00; | ||
| 718 | int32_t ne01; | ||
| 719 | int32_t ne02; | ||
| 720 | uint64_t nb01; | ||
| 721 | uint64_t nb02; | ||
| 722 | uint64_t nb03; | ||
| 723 | int32_t ne11; | ||
| 724 | int32_t ne12; | ||
| 725 | int32_t ne13; | ||
| 726 | uint64_t nb11; | ||
| 727 | uint64_t nb12; | ||
| 728 | uint64_t nb13; | ||
| 729 | uint64_t nb1; | ||
| 730 | uint64_t nb2; | ||
| 731 | uint64_t nb3; | ||
| 732 | float scale; | ||
| 733 | float max_bias; | ||
| 734 | float m0; | ||
| 735 | float m1; | ||
| 736 | int32_t n_head_log2; | ||
| 737 | } ggml_metal_kargs_soft_max; | ||
| 738 | |||
| 739 | typedef struct { | ||
| 740 | int64_t ne00; | ||
| 741 | int64_t ne01; | ||
| 742 | int64_t ne02; | ||
| 743 | uint64_t nb00; | ||
| 744 | uint64_t nb01; | ||
| 745 | uint64_t nb02; | ||
| 746 | int64_t ne10; | ||
| 747 | int64_t ne11; | ||
| 748 | uint64_t nb10; | ||
| 749 | uint64_t nb11; | ||
| 750 | int64_t ne0; | ||
| 751 | int64_t ne1; | ||
| 752 | int64_t ne2; | ||
| 753 | uint64_t nb0; | ||
| 754 | uint64_t nb1; | ||
| 755 | uint64_t nb2; | ||
| 756 | } ggml_metal_kargs_ssm_conv; | ||
| 757 | |||
| 758 | typedef struct { | ||
| 759 | int64_t d_state; | ||
| 760 | int64_t d_inner; | ||
| 761 | int64_t n_head; | ||
| 762 | int64_t n_group; | ||
| 763 | int64_t n_seq_tokens; | ||
| 764 | int64_t n_seqs; | ||
| 765 | uint64_t s_off; | ||
| 766 | uint64_t nb00; | ||
| 767 | uint64_t nb01; | ||
| 768 | uint64_t nb02; | ||
| 769 | uint64_t nb03; | ||
| 770 | uint64_t nb10; | ||
| 771 | uint64_t nb11; | ||
| 772 | uint64_t nb12; | ||
| 773 | uint64_t ns12; | ||
| 774 | uint64_t nb13; | ||
| 775 | uint64_t nb20; | ||
| 776 | uint64_t nb21; | ||
| 777 | uint64_t ns21; | ||
| 778 | uint64_t nb22; | ||
| 779 | int64_t ne30; | ||
| 780 | uint64_t nb31; | ||
| 781 | uint64_t nb41; | ||
| 782 | uint64_t nb42; | ||
| 783 | uint64_t ns42; | ||
| 784 | uint64_t nb43; | ||
| 785 | uint64_t nb51; | ||
| 786 | uint64_t nb52; | ||
| 787 | uint64_t ns52; | ||
| 788 | uint64_t nb53; | ||
| 789 | uint64_t nb0; | ||
| 790 | } ggml_metal_kargs_ssm_scan; | ||
| 791 | |||
| 792 | typedef struct { | ||
| 793 | int32_t ne00; | ||
| 794 | int32_t ne01; | ||
| 795 | int32_t ne02; | ||
| 796 | int32_t ne03; | ||
| 797 | uint64_t nb00; | ||
| 798 | uint64_t nb01; | ||
| 799 | uint64_t nb02; | ||
| 800 | uint64_t nb03; | ||
| 801 | int32_t ne10; | ||
| 802 | int32_t ne11; | ||
| 803 | int32_t ne12; | ||
| 804 | int32_t ne13; | ||
| 805 | uint64_t nb10; | ||
| 806 | uint64_t nb11; | ||
| 807 | uint64_t nb12; | ||
| 808 | uint64_t nb13; | ||
| 809 | int32_t ne0; | ||
| 810 | int32_t ne1; | ||
| 811 | int32_t ne2; | ||
| 812 | int32_t ne3; | ||
| 813 | uint64_t nb0; | ||
| 814 | uint64_t nb1; | ||
| 815 | uint64_t nb2; | ||
| 816 | uint64_t nb3; | ||
| 817 | } ggml_metal_kargs_solve_tri; | ||
| 818 | |||
| 819 | typedef struct { | ||
| 820 | int32_t ne00t; | ||
| 821 | int32_t ne00; | ||
| 822 | uint64_t nb01; | ||
| 823 | uint64_t nb02; | ||
| 824 | uint64_t nb03; | ||
| 825 | int32_t ne10; | ||
| 826 | uint64_t nb10; | ||
| 827 | uint64_t nb11; | ||
| 828 | uint64_t nb12; | ||
| 829 | uint64_t nb1; | ||
| 830 | uint64_t nb2; | ||
| 831 | uint64_t nb3; | ||
| 832 | } ggml_metal_kargs_get_rows; | ||
| 833 | |||
| 834 | typedef struct { | ||
| 835 | int32_t nk0; | ||
| 836 | int32_t ne01; | ||
| 837 | uint64_t nb01; | ||
| 838 | uint64_t nb02; | ||
| 839 | uint64_t nb03; | ||
| 840 | int32_t ne11; | ||
| 841 | int32_t ne12; | ||
| 842 | uint64_t nb10; | ||
| 843 | uint64_t nb11; | ||
| 844 | uint64_t nb12; | ||
| 845 | uint64_t nb1; | ||
| 846 | uint64_t nb2; | ||
| 847 | uint64_t nb3; | ||
| 848 | } ggml_metal_kargs_set_rows; | ||
| 849 | |||
| 850 | typedef struct { | ||
| 851 | int32_t ne00; | ||
| 852 | int32_t ne01; | ||
| 853 | int32_t ne02; | ||
| 854 | int32_t ne03; | ||
| 855 | uint64_t nb00; | ||
| 856 | uint64_t nb01; | ||
| 857 | uint64_t nb02; | ||
| 858 | uint64_t nb03; | ||
| 859 | int32_t ne0; | ||
| 860 | int32_t ne1; | ||
| 861 | int32_t ne2; | ||
| 862 | int32_t ne3; | ||
| 863 | uint64_t nb0; | ||
| 864 | uint64_t nb1; | ||
| 865 | uint64_t nb2; | ||
| 866 | uint64_t nb3; | ||
| 867 | } ggml_metal_kargs_diag; | ||
| 868 | |||
| 869 | typedef struct { | ||
| 870 | int64_t ne00; | ||
| 871 | int64_t ne01; | ||
| 872 | int64_t ne02; | ||
| 873 | int64_t ne03; | ||
| 874 | uint64_t nb00; | ||
| 875 | uint64_t nb01; | ||
| 876 | uint64_t nb02; | ||
| 877 | uint64_t nb03; | ||
| 878 | int64_t ne0; | ||
| 879 | int64_t ne1; | ||
| 880 | int64_t ne2; | ||
| 881 | int64_t ne3; | ||
| 882 | uint64_t nb0; | ||
| 883 | uint64_t nb1; | ||
| 884 | uint64_t nb2; | ||
| 885 | uint64_t nb3; | ||
| 886 | float sf0; | ||
| 887 | float sf1; | ||
| 888 | float sf2; | ||
| 889 | float sf3; | ||
| 890 | } ggml_metal_kargs_upscale; | ||
| 891 | |||
| 892 | typedef struct { | ||
| 893 | int64_t ne00; | ||
| 894 | int64_t ne01; | ||
| 895 | int64_t ne02; | ||
| 896 | int64_t ne03; | ||
| 897 | uint64_t nb00; | ||
| 898 | uint64_t nb01; | ||
| 899 | uint64_t nb02; | ||
| 900 | uint64_t nb03; | ||
| 901 | int64_t ne0; | ||
| 902 | int64_t ne1; | ||
| 903 | int64_t ne2; | ||
| 904 | int64_t ne3; | ||
| 905 | uint64_t nb0; | ||
| 906 | uint64_t nb1; | ||
| 907 | uint64_t nb2; | ||
| 908 | uint64_t nb3; | ||
| 909 | } ggml_metal_kargs_pad; | ||
| 910 | |||
| 911 | typedef struct { | ||
| 912 | int64_t ne00; | ||
| 913 | int64_t ne01; | ||
| 914 | int64_t ne02; | ||
| 915 | int64_t ne03; | ||
| 916 | uint64_t nb00; | ||
| 917 | uint64_t nb01; | ||
| 918 | uint64_t nb02; | ||
| 919 | uint64_t nb03; | ||
| 920 | int64_t ne0; | ||
| 921 | int64_t ne1; | ||
| 922 | int64_t ne2; | ||
| 923 | int64_t ne3; | ||
| 924 | uint64_t nb0; | ||
| 925 | uint64_t nb1; | ||
| 926 | uint64_t nb2; | ||
| 927 | uint64_t nb3; | ||
| 928 | int32_t p0; | ||
| 929 | int32_t p1; | ||
| 930 | } ggml_metal_kargs_pad_reflect_1d; | ||
| 931 | |||
| 932 | typedef struct { | ||
| 933 | uint64_t nb1; | ||
| 934 | int dim; | ||
| 935 | int max_period; | ||
| 936 | } ggml_metal_kargs_timestep_embedding; | ||
| 937 | |||
| 938 | typedef struct { | ||
| 939 | int32_t ne00; | ||
| 940 | int32_t ne01; | ||
| 941 | int32_t ne02; | ||
| 942 | int32_t ne03; | ||
| 943 | uint64_t nb00; | ||
| 944 | uint64_t nb01; | ||
| 945 | uint64_t nb02; | ||
| 946 | uint64_t nb03; | ||
| 947 | int32_t ne0; | ||
| 948 | int32_t ne1; | ||
| 949 | int32_t ne2; | ||
| 950 | int32_t ne3; | ||
| 951 | uint64_t nb0; | ||
| 952 | uint64_t nb1; | ||
| 953 | uint64_t nb2; | ||
| 954 | uint64_t nb3; | ||
| 955 | } ggml_metal_kargs_tri; | ||
| 956 | |||
| 957 | typedef struct { | ||
| 958 | int32_t ne00; | ||
| 959 | int32_t ne01; | ||
| 960 | int32_t ne02; | ||
| 961 | int32_t ne03; | ||
| 962 | uint64_t nb00; | ||
| 963 | uint64_t nb01; | ||
| 964 | uint64_t nb02; | ||
| 965 | uint64_t nb03; | ||
| 966 | int32_t ne0; | ||
| 967 | int32_t ne1; | ||
| 968 | int32_t ne2; | ||
| 969 | int32_t ne3; | ||
| 970 | int32_t top_k; | ||
| 971 | } ggml_metal_kargs_argsort; | ||
| 972 | |||
| 973 | typedef struct { | ||
| 974 | int64_t ne00; | ||
| 975 | int64_t ne01; | ||
| 976 | int64_t ne02; | ||
| 977 | int64_t ne03; | ||
| 978 | uint64_t nb00; | ||
| 979 | uint64_t nb01; | ||
| 980 | uint64_t nb02; | ||
| 981 | uint64_t nb03; | ||
| 982 | int32_t ne0; | ||
| 983 | int32_t ne1; | ||
| 984 | int32_t ne2; | ||
| 985 | int32_t ne3; | ||
| 986 | int32_t top_k; | ||
| 987 | int32_t len; | ||
| 988 | } ggml_metal_kargs_argsort_merge; | ||
| 989 | |||
| 990 | typedef struct { | ||
| 991 | int64_t ne0; | ||
| 992 | float start; | ||
| 993 | float step; | ||
| 994 | } ggml_metal_kargs_arange; | ||
| 995 | |||
| 996 | typedef struct { | ||
| 997 | int64_t val; | ||
| 998 | } ggml_metal_kargs_memset; | ||
| 999 | |||
| 1000 | typedef struct { | ||
| 1001 | int32_t ne00; | ||
| 1002 | int32_t ne01; | ||
| 1003 | int32_t ne02; | ||
| 1004 | int32_t ne03; | ||
| 1005 | uint64_t nb00; | ||
| 1006 | uint64_t nb01; | ||
| 1007 | uint64_t nb02; | ||
| 1008 | uint64_t nb03; | ||
| 1009 | uint64_t nb10; | ||
| 1010 | uint64_t nb11; | ||
| 1011 | uint64_t nb12; | ||
| 1012 | uint64_t nb13; | ||
| 1013 | } ggml_metal_kargs_count_equal; | ||
| 1014 | |||
| 1015 | typedef struct { | ||
| 1016 | int32_t k0; | ||
| 1017 | int32_t k1; | ||
| 1018 | int32_t s0; | ||
| 1019 | int32_t s1; | ||
| 1020 | int32_t p0; | ||
| 1021 | int32_t p1; | ||
| 1022 | int64_t IH; | ||
| 1023 | int64_t IW; | ||
| 1024 | int64_t OH; | ||
| 1025 | int64_t OW; | ||
| 1026 | int64_t np; | ||
| 1027 | } ggml_metal_kargs_pool_2d; | ||
| 1028 | |||
| 1029 | typedef struct { | ||
| 1030 | int32_t k0; | ||
| 1031 | int32_t s0; | ||
| 1032 | int32_t p0; | ||
| 1033 | int64_t IW; | ||
| 1034 | int64_t OW; | ||
| 1035 | int64_t np; | ||
| 1036 | } ggml_metal_kargs_pool_1d; | ||
| 1037 | |||
| 1038 | typedef struct { | ||
| 1039 | int64_t ne00; | ||
| 1040 | uint64_t nb01; | ||
| 1041 | } ggml_metal_kargs_argmax; | ||
| 1042 | |||
| 1043 | typedef struct { | ||
| 1044 | int64_t np; | ||
| 1045 | } ggml_metal_kargs_opt_step_adamw; | ||
| 1046 | |||
| 1047 | typedef struct { | ||
| 1048 | int64_t np; | ||
| 1049 | } ggml_metal_kargs_opt_step_sgd; | ||
| 1050 | |||
| 1051 | #endif // GGML_METAL_IMPL | ||
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp b/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp new file mode 100644 index 0000000..7db95d1 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.cpp | |||
| @@ -0,0 +1,4222 @@ | |||
| 1 | #include "ggml-metal-ops.h" | ||
| 2 | |||
| 3 | #include "ggml.h" | ||
| 4 | #include "ggml-impl.h" | ||
| 5 | #include "ggml-backend-impl.h" | ||
| 6 | |||
| 7 | #include "ggml-metal-impl.h" | ||
| 8 | #include "ggml-metal-common.h" | ||
| 9 | #include "ggml-metal-device.h" | ||
| 10 | |||
| 11 | #include <cassert> | ||
| 12 | #include <algorithm> | ||
| 13 | #include <limits> | ||
| 14 | #include <cmath> | ||
| 15 | |||
| 16 | static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) { | ||
| 17 | if (!t) { | ||
| 18 | return { nullptr, 0 }; | ||
| 19 | } | ||
| 20 | |||
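    // views do not own data: when view_src is set, the bytes live in the parent tensor's buffer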
| 21 | ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer; | ||
| 22 | |||
| 23 | ggml_metal_buffer_t ctx = (ggml_metal_buffer_t) buffer->context; | ||
| 24 | |||
| 25 | return ggml_metal_buffer_get_id(ctx, t); | ||
| 26 | } | ||
| 27 | |||
| 28 | struct ggml_metal_op { | ||
| 29 | ggml_metal_op( | ||
| 30 | ggml_metal_device_t dev, | ||
| 31 | ggml_metal_cmd_buf_t cmd_buf, | ||
| 32 | ggml_cgraph * gf, | ||
| 33 | int idx_start, | ||
| 34 | int idx_end, | ||
| 35 | bool use_fusion, | ||
| 36 | bool use_concurrency, | ||
| 37 | bool use_capture, | ||
| 38 | int debug_graph, | ||
| 39 | int debug_fusion) { | ||
| 40 | this->dev = dev; | ||
| 41 | this->lib = ggml_metal_device_get_library(dev); | ||
| 42 | this->enc = ggml_metal_encoder_init(cmd_buf, use_concurrency); | ||
| 43 | this->mem_ranges = ggml_mem_ranges_init(debug_graph); | ||
| 44 | this->idx_start = idx_start; | ||
| 45 | this->idx_end = idx_end; | ||
| 46 | this->use_fusion = use_fusion; | ||
| 47 | this->use_concurrency = use_concurrency; | ||
| 48 | this->use_capture = use_capture; | ||
| 49 | this->debug_graph = debug_graph; | ||
| 50 | this->debug_fusion = debug_fusion; | ||
| 51 | this->gf = gf; | ||
| 52 | |||
| 53 | idxs.reserve(gf->n_nodes); | ||
| 54 | |||
| 55 | // filter empty nodes | ||
| 56 | // TODO: this can be removed when the allocator starts filtering them earlier | ||
| 57 | // https://github.com/ggml-org/llama.cpp/pull/16130#issuecomment-3327905830 | ||
| 58 | for (int i = idx_start; i < idx_end; i++) { | ||
| 59 | if (!ggml_op_is_empty(gf->nodes[i]->op) && !ggml_is_empty(gf->nodes[i])) { | ||
| 60 | idxs.push_back(i); | ||
| 61 | } | ||
| 62 | } | ||
| 63 | } | ||
| 64 | |||
| 65 | ~ggml_metal_op() { | ||
| 66 | ggml_metal_encoder_end_encoding(this->enc); | ||
| 67 | ggml_metal_encoder_free(this->enc); | ||
| 68 | ggml_mem_ranges_free(this->mem_ranges); | ||
| 69 | } | ||
| 70 | |||
| 71 | int n_nodes() const { | ||
| 72 | return idxs.size(); | ||
| 73 | } | ||
| 74 | |||
| 75 | ggml_tensor * node(int i) const { | ||
| 76 | assert(i >= 0 && i < (int) idxs.size()); | ||
| 77 | return ggml_graph_node(gf, idxs[i]); | ||
| 78 | } | ||
| 79 | |||
| 80 | bool can_fuse(int i0, const ggml_op * ops, int n_ops) const { | ||
| 81 | assert(use_fusion); | ||
| 82 | assert(i0 >= 0 && i0 < n_nodes()); | ||
| 83 | |||
| 84 | if (i0 + n_ops > n_nodes()) { | ||
| 85 | return false; | ||
| 86 | } | ||
| 87 | |||
| 88 | return ggml_can_fuse_ext(gf, idxs.data() + i0, ops, n_ops); | ||
| 89 | } | ||
| 90 | |||
| 91 | ggml_metal_device_t dev; | ||
| 92 | ggml_metal_library_t lib; | ||
| 93 | ggml_metal_encoder_t enc; | ||
| 94 | ggml_mem_ranges_t mem_ranges; | ||
| 95 | |||
| 96 | bool use_fusion; | ||
| 97 | bool use_concurrency; | ||
| 98 | bool use_capture; | ||
| 99 | |||
| 100 | int debug_graph; | ||
| 101 | int debug_fusion; | ||
| 102 | |||
| 103 | private: | ||
| 104 | ggml_cgraph * gf; | ||
| 105 | |||
| 106 | int idx_start; | ||
| 107 | int idx_end; | ||
| 108 | |||
| 109 | // non-empty node indices | ||
| 110 | std::vector<int> idxs; | ||
| 111 | }; | ||
| 112 | |||
| 113 | ggml_metal_op_t ggml_metal_op_init( | ||
| 114 | ggml_metal_device_t dev, | ||
| 115 | ggml_metal_cmd_buf_t cmd_buf, | ||
| 116 | ggml_cgraph * gf, | ||
| 117 | int idx_start, | ||
| 118 | int idx_end, | ||
| 119 | bool use_fusion, | ||
| 120 | bool use_concurrency, | ||
| 121 | bool use_capture, | ||
| 122 | int debug_graph, | ||
| 123 | int debug_fusion) { | ||
| 124 | ggml_metal_op_t res = new ggml_metal_op( | ||
| 125 | dev, | ||
| 126 | cmd_buf, | ||
| 127 | gf, | ||
| 128 | idx_start, | ||
| 129 | idx_end, | ||
| 130 | use_fusion, | ||
| 131 | use_concurrency, | ||
| 132 | use_capture, | ||
| 133 | debug_graph, | ||
| 134 | debug_fusion); | ||
| 135 | |||
| 136 | return res; | ||
| 137 | } | ||
| 138 | |||
| 139 | void ggml_metal_op_free(ggml_metal_op_t ctx) { | ||
| 140 | delete ctx; | ||
| 141 | } | ||
| 142 | |||
| 143 | int ggml_metal_op_n_nodes(ggml_metal_op_t ctx) { | ||
| 144 | return ctx->n_nodes(); | ||
| 145 | } | ||
| 146 | |||
| 147 | static bool ggml_metal_op_concurrency_reset(ggml_metal_op_t ctx) { | ||
| 148 | if (!ctx->mem_ranges) { | ||
| 149 | return true; | ||
| 150 | } | ||
| 151 | |||
| 152 | ggml_metal_encoder_memory_barrier(ctx->enc); | ||
| 153 | |||
| 154 | ggml_mem_ranges_reset(ctx->mem_ranges); | ||
| 155 | |||
| 156 | return true; | ||
| 157 | } | ||
| 158 | |||
| 159 | static bool ggml_metal_op_concurrency_check(ggml_metal_op_t ctx, const ggml_tensor * node) { | ||
| 160 | if (!ctx->mem_ranges) { | ||
| 161 | return false; | ||
| 162 | } | ||
| 163 | |||
| 164 | return ggml_mem_ranges_check(ctx->mem_ranges, node); | ||
| 165 | } | ||
| 166 | |||
| 167 | static bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor * node) { | ||
| 168 | if (!ctx->mem_ranges) { | ||
| 169 | return true; | ||
| 170 | } | ||
| 171 | |||
| 172 | return ggml_mem_ranges_add(ctx->mem_ranges, node); | ||
| 173 | } | ||
| 174 | |||
| 175 | static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { | ||
| 176 | struct ggml_tensor * node = ctx->node(idx); | ||
| 177 | |||
| 178 | //GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op)); | ||
| 179 | |||
| 180 | if (ggml_is_empty(node)) { | ||
| 181 | return 1; | ||
| 182 | } | ||
| 183 | |||
| 184 | switch (node->op) { | ||
| 185 | case GGML_OP_NONE: | ||
| 186 | case GGML_OP_RESHAPE: | ||
| 187 | case GGML_OP_VIEW: | ||
| 188 | case GGML_OP_TRANSPOSE: | ||
| 189 | case GGML_OP_PERMUTE: | ||
| 190 | { | ||
| 191 | // noop -> next node | ||
| 192 | if (ctx->debug_graph > 0) { | ||
| 193 | GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), "(noop)"); | ||
| 194 | } | ||
| 195 | } return 1; | ||
| 196 | default: | ||
| 197 | { | ||
| 198 | } break; | ||
| 199 | } | ||
| 200 | |||
| 201 | if (!ggml_metal_device_supports_op(ctx->dev, node)) { | ||
| 202 | GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(node)); | ||
| 203 | GGML_ABORT("unsupported op"); | ||
| 204 | } | ||
| 205 | |||
| 206 | if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { | ||
| 207 | return 1; | ||
| 208 | } | ||
| 209 | |||
| 210 | int n_fuse = 1; | ||
| 211 | |||
| 212 | // check if the current node can run concurrently with other nodes before it | ||
| 213 | // the condition is that: | ||
| 214 | // - the current node cannot write to any previous src or dst ranges | ||
| 215 | // - the current node cannot read from any previous dst ranges | ||
| 216 | // | ||
| 217 | // if the condition is not satisfied, we put a memory barrier and clear all ranges | ||
| 218 | // otherwise, we add the new ranges to the encoding context and process the node concurrently | ||
| 219 | // | ||
| 220 | { | ||
| 221 | const bool is_concurrent = ggml_metal_op_concurrency_check(ctx, node); | ||
| 222 | |||
| 223 | if (!is_concurrent) { | ||
| 224 | ggml_metal_op_concurrency_reset(ctx); | ||
| 225 | } | ||
| 226 | |||
| 227 | if (ctx->debug_graph > 0) { | ||
| 228 | GGML_LOG_DEBUG("%s: node[%5d] - %-12s %-12s %s\n", __func__, idx, ggml_op_name(node->op), ggml_get_name(node), is_concurrent ? "(concurrent)" : ""); | ||
| 229 | } | ||
| 230 | if (ctx->debug_graph > 1) { | ||
| 231 | GGML_TENSOR_LOCALS( int64_t, ne0, node->src[0], ne); | ||
| 232 | GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb); | ||
| 233 | GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne); | ||
| 234 | GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb); | ||
| 235 | GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne); | ||
| 236 | GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb); | ||
| 237 | GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne); | ||
| 238 | GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb); | ||
| 239 | GGML_TENSOR_LOCALS( int64_t, ne, node, ne); | ||
| 240 | GGML_TENSOR_LOCALS(uint64_t, nb, node, nb); | ||
| 241 | |||
| 242 | if (node->src[0]) { | ||
| 243 | GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[0]->type), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03, | ||
| 244 | ggml_is_contiguous(node->src[0]), node->src[0]->name); | ||
| 245 | } | ||
| 246 | if (node->src[1]) { | ||
| 247 | GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, | ||
| 248 | ggml_is_contiguous(node->src[1]), node->src[1]->name); | ||
| 249 | } | ||
| 250 | if (node->src[2]) { | ||
| 251 | GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23, | ||
| 252 | ggml_is_contiguous(node->src[2]), node->src[2]->name); | ||
| 253 | } | ||
| 254 | if (node->src[3]) { | ||
| 255 | GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33, | ||
| 256 | ggml_is_contiguous(node->src[3]), node->src[3]->name); | ||
| 257 | } | ||
| 258 | if (node) { | ||
| 259 | GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, | ||
| 260 | node->name); | ||
| 261 | } | ||
| 262 | } | ||
| 263 | } | ||
| 264 | |||
| 265 | switch (node->op) { | ||
| 266 | case GGML_OP_CONCAT: | ||
| 267 | { | ||
| 268 | n_fuse = ggml_metal_op_concat(ctx, idx); | ||
| 269 | } break; | ||
| 270 | case GGML_OP_ADD: | ||
| 271 | case GGML_OP_SUB: | ||
| 272 | case GGML_OP_MUL: | ||
| 273 | case GGML_OP_DIV: | ||
| 274 | { | ||
| 275 | n_fuse = ggml_metal_op_bin(ctx, idx); | ||
| 276 | } break; | ||
| 277 | case GGML_OP_ADD_ID: | ||
| 278 | { | ||
| 279 | n_fuse = ggml_metal_op_add_id(ctx, idx); | ||
| 280 | } break; | ||
| 281 | case GGML_OP_REPEAT: | ||
| 282 | { | ||
| 283 | n_fuse = ggml_metal_op_repeat(ctx, idx); | ||
| 284 | } break; | ||
| 285 | case GGML_OP_ACC: | ||
| 286 | { | ||
| 287 | n_fuse = ggml_metal_op_acc(ctx, idx); | ||
| 288 | } break; | ||
| 289 | case GGML_OP_SCALE: | ||
| 290 | case GGML_OP_FILL: | ||
| 291 | case GGML_OP_CLAMP: | ||
| 292 | case GGML_OP_LEAKY_RELU: | ||
| 293 | case GGML_OP_SQR: | ||
| 294 | case GGML_OP_SQRT: | ||
| 295 | case GGML_OP_SIN: | ||
| 296 | case GGML_OP_COS: | ||
| 297 | case GGML_OP_LOG: | ||
| 298 | case GGML_OP_UNARY: | ||
| 299 | { | ||
| 300 | n_fuse = ggml_metal_op_unary(ctx, idx); | ||
| 301 | } break; | ||
| 302 | case GGML_OP_GLU: | ||
| 303 | { | ||
| 304 | n_fuse = ggml_metal_op_glu(ctx, idx); | ||
| 305 | } break; | ||
| 306 | case GGML_OP_SUM: | ||
| 307 | { | ||
| 308 | n_fuse = ggml_metal_op_sum(ctx, idx); | ||
| 309 | } break; | ||
| 310 | case GGML_OP_SUM_ROWS: | ||
| 311 | case GGML_OP_MEAN: | ||
| 312 | { | ||
| 313 | n_fuse = ggml_metal_op_sum_rows(ctx, idx); | ||
| 314 | } break; | ||
| 315 | case GGML_OP_CUMSUM: | ||
| 316 | { | ||
| 317 | n_fuse = ggml_metal_op_cumsum(ctx, idx); | ||
| 318 | } break; | ||
| 319 | case GGML_OP_SOFT_MAX: | ||
| 320 | { | ||
| 321 | n_fuse = ggml_metal_op_soft_max(ctx, idx); | ||
| 322 | } break; | ||
| 323 | case GGML_OP_SSM_CONV: | ||
| 324 | { | ||
| 325 | n_fuse = ggml_metal_op_ssm_conv(ctx, idx); | ||
| 326 | } break; | ||
| 327 | case GGML_OP_SSM_SCAN: | ||
| 328 | { | ||
| 329 | n_fuse = ggml_metal_op_ssm_scan(ctx, idx); | ||
| 330 | } break; | ||
| 331 | case GGML_OP_RWKV_WKV6: | ||
| 332 | case GGML_OP_RWKV_WKV7: | ||
| 333 | { | ||
| 334 | n_fuse = ggml_metal_op_rwkv(ctx, idx); | ||
| 335 | } break; | ||
| 336 | case GGML_OP_SOLVE_TRI: | ||
| 337 | { | ||
| 338 | n_fuse = ggml_metal_op_solve_tri(ctx, idx); | ||
| 339 | } break; | ||
| 340 | case GGML_OP_MUL_MAT: | ||
| 341 | { | ||
| 342 | n_fuse = ggml_metal_op_mul_mat(ctx, idx); | ||
| 343 | } break; | ||
| 344 | case GGML_OP_MUL_MAT_ID: | ||
| 345 | { | ||
| 346 | n_fuse = ggml_metal_op_mul_mat_id(ctx, idx); | ||
| 347 | } break; | ||
| 348 | case GGML_OP_GET_ROWS: | ||
| 349 | { | ||
| 350 | n_fuse = ggml_metal_op_get_rows(ctx, idx); | ||
| 351 | } break; | ||
| 352 | case GGML_OP_SET_ROWS: | ||
| 353 | { | ||
| 354 | n_fuse = ggml_metal_op_set_rows(ctx, idx); | ||
| 355 | } break; | ||
| 356 | case GGML_OP_DIAG: | ||
| 357 | { | ||
| 358 | n_fuse = ggml_metal_op_diag(ctx, idx); | ||
| 359 | } break; | ||
| 360 | case GGML_OP_L2_NORM: | ||
| 361 | { | ||
| 362 | n_fuse = ggml_metal_op_l2_norm(ctx, idx); | ||
| 363 | } break; | ||
| 364 | case GGML_OP_GROUP_NORM: | ||
| 365 | { | ||
| 366 | n_fuse = ggml_metal_op_group_norm(ctx, idx); | ||
| 367 | } break; | ||
| 368 | case GGML_OP_NORM: | ||
| 369 | case GGML_OP_RMS_NORM: | ||
| 370 | { | ||
| 371 | n_fuse = ggml_metal_op_norm(ctx, idx); | ||
| 372 | } break; | ||
| 373 | case GGML_OP_ROPE: | ||
| 374 | { | ||
| 375 | n_fuse = ggml_metal_op_rope(ctx, idx); | ||
| 376 | } break; | ||
| 377 | case GGML_OP_IM2COL: | ||
| 378 | { | ||
| 379 | n_fuse = ggml_metal_op_im2col(ctx, idx); | ||
| 380 | } break; | ||
| 381 | case GGML_OP_CONV_2D: | ||
| 382 | { | ||
| 383 | n_fuse = ggml_metal_op_conv_2d(ctx, idx); | ||
| 384 | } break; | ||
| 385 | case GGML_OP_CONV_TRANSPOSE_1D: | ||
| 386 | { | ||
| 387 | n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx); | ||
| 388 | } break; | ||
| 389 | case GGML_OP_CONV_TRANSPOSE_2D: | ||
| 390 | { | ||
| 391 | n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx); | ||
| 392 | } break; | ||
| 393 | case GGML_OP_UPSCALE: | ||
| 394 | { | ||
| 395 | n_fuse = ggml_metal_op_upscale(ctx, idx); | ||
| 396 | } break; | ||
| 397 | case GGML_OP_PAD: | ||
| 398 | { | ||
| 399 | n_fuse = ggml_metal_op_pad(ctx, idx); | ||
| 400 | } break; | ||
| 401 | case GGML_OP_PAD_REFLECT_1D: | ||
| 402 | { | ||
| 403 | n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx); | ||
| 404 | } break; | ||
| 405 | case GGML_OP_ARANGE: | ||
| 406 | { | ||
| 407 | n_fuse = ggml_metal_op_arange(ctx, idx); | ||
| 408 | } break; | ||
| 409 | case GGML_OP_TIMESTEP_EMBEDDING: | ||
| 410 | { | ||
| 411 | n_fuse = ggml_metal_op_timestep_embedding(ctx, idx); | ||
| 412 | } break; | ||
| 413 | case GGML_OP_ARGSORT: | ||
| 414 | { | ||
| 415 | n_fuse = ggml_metal_op_argsort(ctx, idx); | ||
| 416 | } break; | ||
| 417 | case GGML_OP_TOP_K: | ||
| 418 | { | ||
| 419 | n_fuse = ggml_metal_op_top_k(ctx, idx); | ||
| 420 | } break; | ||
| 421 | case GGML_OP_TRI: | ||
| 422 | { | ||
| 423 | n_fuse = ggml_metal_op_tri(ctx, idx); | ||
| 424 | } break; | ||
| 425 | case GGML_OP_FLASH_ATTN_EXT: | ||
| 426 | { | ||
| 427 | n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx); | ||
| 428 | } break; | ||
| 429 | case GGML_OP_DUP: | ||
| 430 | case GGML_OP_CPY: | ||
| 431 | case GGML_OP_CONT: | ||
| 432 | { | ||
| 433 | n_fuse = ggml_metal_op_cpy(ctx, idx); | ||
| 434 | } break; | ||
| 435 | case GGML_OP_POOL_1D: | ||
| 436 | { | ||
| 437 | n_fuse = ggml_metal_op_pool_1d(ctx, idx); | ||
| 438 | } break; | ||
| 439 | case GGML_OP_POOL_2D: | ||
| 440 | { | ||
| 441 | n_fuse = ggml_metal_op_pool_2d(ctx, idx); | ||
| 442 | } break; | ||
| 443 | case GGML_OP_ARGMAX: | ||
| 444 | { | ||
| 445 | n_fuse = ggml_metal_op_argmax(ctx, idx); | ||
| 446 | } break; | ||
| 447 | case GGML_OP_OPT_STEP_ADAMW: | ||
| 448 | { | ||
| 449 | n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx); | ||
| 450 | } break; | ||
| 451 | case GGML_OP_OPT_STEP_SGD: | ||
| 452 | { | ||
| 453 | n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx); | ||
| 454 | } break; | ||
| 455 | case GGML_OP_COUNT_EQUAL: | ||
| 456 | { | ||
| 457 | n_fuse = ggml_metal_op_count_equal(ctx, idx); | ||
| 458 | } break; | ||
| 459 | default: | ||
| 460 | { | ||
| 461 | GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op)); | ||
| 462 | GGML_ABORT("fatal error"); | ||
| 463 | } | ||
| 464 | } | ||
| 465 | |||
| 466 | if (ctx->debug_graph > 0) { | ||
| 467 | if (n_fuse > 1) { | ||
| 468 | GGML_LOG_DEBUG("%s: fuse %d ops\n", __func__, n_fuse); | ||
| 469 | } | ||
| 470 | } | ||
| 471 | |||
| 472 | // update the mem ranges in the encoding context | ||
| 473 | for (int i = 0; i < n_fuse; ++i) { | ||
| 474 | if (!ggml_metal_op_concurrency_add(ctx, ctx->node(idx + i))) { | ||
| 475 | ggml_metal_op_concurrency_reset(ctx); | ||
| 476 | } | ||
| 477 | } | ||
| 478 | |||
| 479 | return n_fuse; | ||
| 480 | } | ||
| 481 | |||
| 482 | int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) { | ||
| 483 | if (ctx->use_capture) { | ||
| 484 | ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ctx->node(idx))); | ||
| 485 | } | ||
| 486 | |||
| 487 | int res = ggml_metal_op_encode_impl(ctx, idx); | ||
| 488 | if (idx + res > ctx->n_nodes()) { | ||
| 489 | GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s", | ||
| 490 | "https://github.com/ggml-org/llama.cpp/pull/14849"); | ||
| 491 | } | ||
| 492 | |||
| 493 | if (ctx->use_capture) { | ||
| 494 | ggml_metal_encoder_debug_group_pop(ctx->enc); | ||
| 495 | } | ||
| 496 | |||
| 497 | return res; | ||
| 498 | } | ||
| 499 | |||
| 500 | int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { | ||
| 501 | ggml_tensor * op = ctx->node(idx); | ||
| 502 | |||
| 503 | ggml_metal_library_t lib = ctx->lib; | ||
| 504 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 505 | |||
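    // GGML_TENSOR_LOCALS declares local copies of the tensor dims: ne00..ne03/nb00..nb03 for src0, ne10../nb10.. for src1, and ne0..ne3/nb0..nb3 for the op itself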
| 506 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 507 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 508 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 509 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 510 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 511 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 512 | |||
| 513 | const int32_t dim = ((const int32_t *) op->op_params)[0]; | ||
| 514 | |||
| 515 | ggml_metal_kargs_concat args = { | ||
| 516 | /*.ne00 =*/ ne00, | ||
| 517 | /*.ne01 =*/ ne01, | ||
| 518 | /*.ne02 =*/ ne02, | ||
| 519 | /*.ne03 =*/ ne03, | ||
| 520 | /*.nb00 =*/ nb00, | ||
| 521 | /*.nb01 =*/ nb01, | ||
| 522 | /*.nb02 =*/ nb02, | ||
| 523 | /*.nb03 =*/ nb03, | ||
| 524 | /*.ne10 =*/ ne10, | ||
| 525 | /*.ne11 =*/ ne11, | ||
| 526 | /*.ne12 =*/ ne12, | ||
| 527 | /*.ne13 =*/ ne13, | ||
| 528 | /*.nb10 =*/ nb10, | ||
| 529 | /*.nb11 =*/ nb11, | ||
| 530 | /*.nb12 =*/ nb12, | ||
| 531 | /*.nb13 =*/ nb13, | ||
| 532 | /*.ne0 =*/ ne0, | ||
| 533 | /*.ne1 =*/ ne1, | ||
| 534 | /*.ne2 =*/ ne2, | ||
| 535 | /*.ne3 =*/ ne3, | ||
| 536 | /*.nb0 =*/ nb0, | ||
| 537 | /*.nb1 =*/ nb1, | ||
| 538 | /*.nb2 =*/ nb2, | ||
| 539 | /*.nb3 =*/ nb3, | ||
| 540 | /*.dim =*/ dim, | ||
| 541 | }; | ||
| 542 | |||
| 543 | auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT); | ||
| 544 | |||
| 545 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 546 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 547 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 548 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 549 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); | ||
| 550 | |||
| 551 | const int nth = std::min(1024, ne0); | ||
| 552 | |||
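    // one threadgroup per (ne1, ne2, ne3) output row; the nth threads of each group cover the elements along dim 0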
| 553 | ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); | ||
| 554 | |||
| 555 | return 1; | ||
| 556 | } | ||
| 557 | |||
| 558 | int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) { | ||
| 559 | ggml_tensor * op = ctx->node(idx); | ||
| 560 | |||
| 561 | ggml_metal_library_t lib = ctx->lib; | ||
| 562 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 563 | |||
| 564 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 565 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 566 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 567 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 568 | |||
| 569 | auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type); | ||
| 570 | |||
| 571 | ggml_metal_kargs_repeat args = { | ||
| 572 | /*.ne00 =*/ ne00, | ||
| 573 | /*.ne01 =*/ ne01, | ||
| 574 | /*.ne02 =*/ ne02, | ||
| 575 | /*.ne03 =*/ ne03, | ||
| 576 | /*.nb00 =*/ nb00, | ||
| 577 | /*.nb01 =*/ nb01, | ||
| 578 | /*.nb02 =*/ nb02, | ||
| 579 | /*.nb03 =*/ nb03, | ||
| 580 | /*.ne0 =*/ ne0, | ||
| 581 | /*.ne1 =*/ ne1, | ||
| 582 | /*.ne2 =*/ ne2, | ||
| 583 | /*.ne3 =*/ ne3, | ||
| 584 | /*.nb0 =*/ nb0, | ||
| 585 | /*.nb1 =*/ nb1, | ||
| 586 | /*.nb2 =*/ nb2, | ||
| 587 | /*.nb3 =*/ nb3, | ||
| 588 | }; | ||
| 589 | |||
| 590 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 591 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 592 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 593 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); | ||
| 594 | |||
| 595 | const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); | ||
| 596 | |||
| 597 | ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); | ||
| 598 | |||
| 599 | return 1; | ||
| 600 | } | ||
| 601 | |||
| 602 | int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { | ||
| 603 | ggml_tensor * op = ctx->node(idx); | ||
| 604 | |||
| 605 | ggml_metal_library_t lib = ctx->lib; | ||
| 606 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 607 | |||
| 608 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 609 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 610 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 611 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 612 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 613 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 614 | |||
| 615 | GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); | ||
| 616 | GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); | ||
| 617 | GGML_ASSERT(op->type == GGML_TYPE_F32); | ||
| 618 | |||
| 619 | GGML_ASSERT(ggml_is_contiguous(op->src[0])); | ||
| 620 | GGML_ASSERT(ggml_is_contiguous(op->src[1])); | ||
| 621 | |||
| 622 | const size_t pnb1 = ((const int32_t *) op->op_params)[0]; | ||
| 623 | const size_t pnb2 = ((const int32_t *) op->op_params)[1]; | ||
| 624 | const size_t pnb3 = ((const int32_t *) op->op_params)[2]; | ||
| 625 | const size_t offs = ((const int32_t *) op->op_params)[3]; | ||
| 626 | |||
| 627 | const bool inplace = (bool) ((const int32_t *) op->op_params)[4]; | ||
| 628 | |||
| 629 | if (!inplace) { | ||
| 630 | // run a separate kernel to copy src->dst | ||
| 631 | // not sure how to avoid this | ||
| 632 | // TODO: make a simpler cpy_bytes kernel | ||
| 633 | |||
| 634 | //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj; | ||
| 635 | auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); | ||
| 636 | |||
| 637 | ggml_metal_kargs_cpy args = { | ||
| 638 | /*.nk0 =*/ ne00, | ||
| 639 | /*.ne00 =*/ ne00, | ||
| 640 | /*.ne01 =*/ ne01, | ||
| 641 | /*.ne02 =*/ ne02, | ||
| 642 | /*.ne03 =*/ ne03, | ||
| 643 | /*.nb00 =*/ nb00, | ||
| 644 | /*.nb01 =*/ nb01, | ||
| 645 | /*.nb02 =*/ nb02, | ||
| 646 | /*.nb03 =*/ nb03, | ||
| 647 | /*.ne0 =*/ ne0, | ||
| 648 | /*.ne1 =*/ ne1, | ||
| 649 | /*.ne2 =*/ ne2, | ||
| 650 | /*.ne3 =*/ ne3, | ||
| 651 | /*.nb0 =*/ nb0, | ||
| 652 | /*.nb1 =*/ nb1, | ||
| 653 | /*.nb2 =*/ nb2, | ||
| 654 | /*.nb3 =*/ nb3, | ||
| 655 | }; | ||
| 656 | |||
| 657 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 658 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 659 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 660 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); | ||
| 661 | |||
| 662 | const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); | ||
| 663 | |||
| 664 | ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); | ||
| 665 | |||
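    // the ADD kernel encoded below reads the dst that was just written, so force a barrier before continuing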
| 666 | ggml_metal_op_concurrency_reset(ctx); | ||
| 667 | } | ||
| 668 | |||
| 669 | ggml_metal_kargs_bin args = { | ||
| 670 | /*.ne00 =*/ ne00, | ||
| 671 | /*.ne01 =*/ ne01, | ||
| 672 | /*.ne02 =*/ ne02, | ||
| 673 | /*.ne03 =*/ ne03, | ||
| 674 | /*.nb00 =*/ nb00, | ||
| 675 | /*.nb01 =*/ pnb1, | ||
| 676 | /*.nb02 =*/ pnb2, | ||
| 677 | /*.nb03 =*/ pnb3, | ||
| 678 | /*.ne10 =*/ ne10, | ||
| 679 | /*.ne11 =*/ ne11, | ||
| 680 | /*.ne12 =*/ ne12, | ||
| 681 | /*.ne13 =*/ ne13, | ||
| 682 | /*.nb10 =*/ nb10, | ||
| 683 | /*.nb11 =*/ nb11, | ||
| 684 | /*.nb12 =*/ nb12, | ||
| 685 | /*.nb13 =*/ nb13, | ||
| 686 | /*.ne0 =*/ ne0, | ||
| 687 | /*.ne1 =*/ ne1, | ||
| 688 | /*.ne2 =*/ ne2, | ||
| 689 | /*.ne3 =*/ ne3, | ||
| 690 | /*.nb0 =*/ nb0, | ||
| 691 | /*.nb1 =*/ pnb1, | ||
| 692 | /*.nb2 =*/ pnb2, | ||
| 693 | /*.nb3 =*/ pnb3, | ||
| 694 | /*.offs =*/ offs, | ||
| 695 | /*.o1 =*/ { 0 }, | ||
| 696 | }; | ||
| 697 | |||
| 698 | auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD); | ||
| 699 | |||
| 700 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 701 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 702 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 703 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 704 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); | ||
| 705 | |||
| 706 | const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); | ||
| 707 | |||
| 708 | ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1); | ||
| 709 | |||
| 710 | return 1; | ||
| 711 | } | ||
| 712 | |||
| 713 | int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { | ||
| 714 | ggml_tensor * op = ctx->node(idx); | ||
| 715 | |||
| 716 | ggml_metal_library_t lib = ctx->lib; | ||
| 717 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 718 | |||
| 719 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 720 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 721 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 722 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 723 | |||
| 724 | GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); | ||
| 725 | |||
| 726 | ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); | ||
| 727 | ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); | ||
| 728 | |||
| 729 | ggml_metal_kargs_unary args = { | ||
| 730 | /*.ne00 =*/ ne00, | ||
| 731 | /*.ne01 =*/ ne01, | ||
| 732 | /*.ne02 =*/ ne02, | ||
| 733 | /*.ne03 =*/ ne03, | ||
| 734 | /*.nb00 =*/ nb00, | ||
| 735 | /*.nb01 =*/ nb01, | ||
| 736 | /*.nb02 =*/ nb02, | ||
| 737 | /*.nb03 =*/ nb03, | ||
| 738 | /*.ne0 =*/ ne0, | ||
| 739 | /*.ne1 =*/ ne1, | ||
| 740 | /*.ne2 =*/ ne2, | ||
| 741 | /*.ne3 =*/ ne3, | ||
| 742 | /*.nb0 =*/ nb0, | ||
| 743 | /*.nb1 =*/ nb1, | ||
| 744 | /*.nb2 =*/ nb2, | ||
| 745 | /*.nb3 =*/ nb3, | ||
| 746 | /*.slope =*/ 0.0, | ||
| 747 | /*.scale =*/ 0.0, | ||
| 748 | /*.bias =*/ 0.0, | ||
| 749 | /*.val =*/ 0.0, | ||
| 750 | /*.min =*/ 0.0, | ||
| 751 | /*.max =*/ 0.0, | ||
| 752 | }; | ||
| 753 | |||
| 754 | if (op->op == GGML_OP_LEAKY_RELU) { | ||
| 755 | args.slope = ggml_get_op_params_f32(op, 0); | ||
| 756 | } | ||
| 757 | |||
| 758 | if (op->op == GGML_OP_SCALE) { | ||
| 759 | args.scale = ggml_get_op_params_f32(op, 0); | ||
| 760 | args.bias = ggml_get_op_params_f32(op, 1); | ||
| 761 | } | ||
| 762 | |||
| 763 | if (op->op == GGML_OP_FILL) { | ||
| 764 | args.val = ggml_get_op_params_f32(op, 0); | ||
| 765 | } | ||
| 766 | |||
| 767 | if (op->op == GGML_OP_CLAMP) { | ||
| 768 | args.min = ggml_get_op_params_f32(op, 0); | ||
| 769 | args.max = ggml_get_op_params_f32(op, 1); | ||
| 770 | } | ||
| 771 | |||
| 772 | auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); | ||
| 773 | |||
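    // c4 presumably selects a vectorized variant that works on 4-element vectors, so the element counts are expressed in units of 4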
| 774 | if (pipeline.c4) { | ||
| 775 | args.ne00 = ne00/4; | ||
| 776 | args.ne0 = ne0/4; | ||
| 777 | } | ||
| 778 | |||
| 779 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 780 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 781 | ggml_metal_encoder_set_buffer (enc, bid_src0, 1); | ||
| 782 | ggml_metal_encoder_set_buffer (enc, bid_dst, 2); | ||
| 783 | |||
| 784 | if (pipeline.cnt) { | ||
| 785 | const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op); | ||
| 786 | |||
| 787 | ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); | ||
| 788 | } else { | ||
| 789 | const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); | ||
| 790 | |||
| 791 | const int nth = MIN(args.ne00, nth_max); | ||
| 792 | |||
| 793 | const int nk0 = (args.ne00 + nth - 1)/nth; | ||
| 794 | |||
| 795 | ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1); | ||
| 796 | } | ||
| 797 | |||
| 798 | return 1; | ||
| 799 | } | ||
| 800 | |||
| 801 | int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) { | ||
| 802 | ggml_tensor * op = ctx->node(idx); | ||
| 803 | |||
| 804 | ggml_metal_library_t lib = ctx->lib; | ||
| 805 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 806 | |||
| 807 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 808 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 809 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 810 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 811 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 812 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 813 | |||
| 814 | if (op->src[1]) { | ||
| 815 | GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1])); | ||
| 816 | } | ||
| 817 | |||
| 818 | auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op); | ||
| 819 | |||
| 820 | const int32_t swp = ggml_get_op_params_i32(op, 1); | ||
| 821 | const float alpha = ggml_get_op_params_f32(op, 2); | ||
| 822 | const float limit = ggml_get_op_params_f32(op, 3); | ||
| 823 | |||
| 824 | const int32_t i00 = swp ? ne0 : 0; | ||
| 825 | const int32_t i10 = swp ? 0 : ne0; | ||
| 826 | |||
| 827 | ggml_metal_kargs_glu args = { | ||
| 828 | /*.ne00 =*/ ne00, | ||
| 829 | /*.nb01 =*/ nb01, | ||
| 830 | /*.ne10 =*/ op->src[1] ? ne10 : ne00, | ||
| 831 | /*.nb11 =*/ op->src[1] ? nb11 : nb01, | ||
| 832 | /*.ne0 =*/ ne0, | ||
| 833 | /*.nb1 =*/ nb1, | ||
| 834 | /*.i00 =*/ op->src[1] ? 0 : i00, | ||
| 835 | /*.i10 =*/ op->src[1] ? 0 : i10, | ||
| 836 | /*.alpha=*/ alpha, | ||
| 837 | /*.limit=*/ limit | ||
| 838 | }; | ||
| 839 | |||
| 840 | const int64_t nrows = ggml_nrows(op->src[0]); | ||
| 841 | |||
| 842 | const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2); | ||
| 843 | |||
| 844 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 845 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 846 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 847 | if (op->src[1]) { | ||
| 848 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 849 | } else { | ||
| 850 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 2); | ||
| 851 | } | ||
| 852 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); | ||
| 853 | |||
| 854 | ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); | ||
| 855 | |||
| 856 | return 1; | ||
| 857 | } | ||
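| | |||
| | // A toy CPU sketch of the fused-row indexing used above when there is no separate gate tensor: | ||
| | // a single src0 row of width 2*ne0 holds both halves, the swap flag decides which half is the | ||
| | // activation and which is the gate, and the kernel reads them at offsets i00 and i10. The | ||
| | // silu-style activation in the loop is an illustrative assumption, not taken from the kernels. | ||
| | #include <cmath> | ||
| | #include <cstdio> | ||
| | |||
| | int main() { | ||
| | const int ne0 = 4; // dst row width | ||
| | const float row[2*4] = {1, 2, 3, 4, 5, 6, 7, 8}; // fused src0 row: activation half then gate half | ||
| | const bool swp = false; | ||
| | |||
| | const int i00 = swp ? ne0 : 0; // offset of the activation half | ||
| | const int i10 = swp ? 0 : ne0; // offset of the gate half | ||
| | |||
| | for (int i = 0; i < ne0; ++i) { | ||
| | const float x = row[i00 + i]; | ||
| | const float g = row[i10 + i]; | ||
| | printf("%g\n", (x/(1.0f + expf(-x)))*g); // silu(x)*g | ||
| | } | ||
| | return 0; | ||
| | } | ||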
| 858 | |||
| 859 | int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) { | ||
| 860 | ggml_tensor * op = ctx->node(idx); | ||
| 861 | |||
| 862 | ggml_metal_library_t lib = ctx->lib; | ||
| 863 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 864 | |||
| 865 | const uint64_t n = (uint64_t) ggml_nelements(op->src[0]); | ||
| 866 | |||
| 867 | ggml_metal_kargs_sum args = { | ||
| 868 | /*.np =*/ n, | ||
| 869 | }; | ||
| 870 | |||
| 871 | auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op); | ||
| 872 | |||
| 873 | int nth = 32; // SIMD width | ||
| 874 | |||
| 875 | while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { | ||
| 876 | nth *= 2; | ||
| 877 | } | ||
| 878 | |||
| 879 | nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); | ||
| 880 | nth = std::min(nth, (int) n); | ||
| 881 | |||
| 882 | const int nsg = (nth + 31) / 32; | ||
| 883 | |||
| 884 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 885 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 886 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 887 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); | ||
| 888 | |||
| 889 | ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg * sizeof(float), 0); | ||
| 890 | |||
| 891 | ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1); | ||
| 892 | |||
| 893 | return 1; | ||
| 894 | } | ||
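| | |||
| | // A small standalone sketch of the thread-count selection above: nth starts at the SIMD width and | ||
| | // doubles until it covers n or hits the pipeline limit, and nsg is the number of simdgroups whose | ||
| | // partial sums need one float of threadgroup memory each. The same doubling pattern is used by the | ||
| | // sum_rows and set_rows ops below. Values here are hypothetical. | ||
| | #include <algorithm> | ||
| | #include <cstdint> | ||
| | #include <cstdio> | ||
| | |||
| | int main() { | ||
| | const uint64_t n = 1000; // hypothetical element count | ||
| | const int max_tptg = 1024; // hypothetical pipeline limit | ||
| | |||
| | int nth = 32; // SIMD width | ||
| | while (nth < (int) n && nth < max_tptg) { | ||
| | nth *= 2; | ||
| | } | ||
| | nth = std::min(nth, max_tptg); | ||
| | nth = std::min(nth, (int) n); | ||
| | |||
| | const int nsg = (nth + 31)/32; // simdgroups per threadgroup | ||
| | |||
| | printf("nth = %d, nsg = %d\n", nth, nsg); // nth = 1000, nsg = 32 | ||
| | return 0; | ||
| | } | ||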
| 895 | |||
| 896 | int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { | ||
| 897 | ggml_tensor * op = ctx->node(idx); | ||
| 898 | |||
| 899 | ggml_metal_library_t lib = ctx->lib; | ||
| 900 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 901 | |||
| 902 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 903 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 904 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 905 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 906 | |||
| 907 | ggml_metal_kargs_sum_rows args = { | ||
| 908 | /*.ne00 =*/ ne00, | ||
| 909 | /*.ne01 =*/ ne01, | ||
| 910 | /*.ne02 =*/ ne02, | ||
| 911 | /*.ne03 =*/ ne03, | ||
| 912 | /*.nb00 =*/ nb00, | ||
| 913 | /*.nb01 =*/ nb01, | ||
| 914 | /*.nb02 =*/ nb02, | ||
| 915 | /*.nb03 =*/ nb03, | ||
| 916 | /*.ne0 =*/ ne0, | ||
| 917 | /*.ne1 =*/ ne1, | ||
| 918 | /*.ne2 =*/ ne2, | ||
| 919 | /*.ne3 =*/ ne3, | ||
| 920 | /*.nb0 =*/ nb0, | ||
| 921 | /*.nb1 =*/ nb1, | ||
| 922 | /*.nb2 =*/ nb2, | ||
| 923 | /*.nb3 =*/ nb3, | ||
| 924 | }; | ||
| 925 | |||
| 926 | auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); | ||
| 927 | |||
| 928 | int nth = 32; // SIMD width | ||
| 929 | |||
| 930 | while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { | ||
| 931 | nth *= 2; | ||
| 932 | } | ||
| 933 | |||
| 934 | nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); | ||
| 935 | nth = std::min(nth, ne00); | ||
| 936 | |||
| 937 | const size_t smem = pipeline.smem; | ||
| 938 | |||
| 939 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 940 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 941 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 942 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); | ||
| 943 | |||
| 944 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 945 | |||
| 946 | ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); | ||
| 947 | |||
| 948 | return 1; | ||
| 949 | } | ||
| 950 | |||
| 951 | int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) { | ||
| 952 | ggml_tensor * op = ctx->node(idx); | ||
| 953 | |||
| 954 | ggml_metal_library_t lib = ctx->lib; | ||
| 955 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 956 | |||
| 957 | GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); | ||
| 958 | |||
| 959 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 960 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 961 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 962 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 963 | |||
| 964 | auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op); | ||
| 965 | |||
| 966 | int nth = 1; | ||
| 967 | while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) { | ||
| 968 | nth *= 2; | ||
| 969 | } | ||
| 970 | |||
| 971 | GGML_ASSERT(ne00 <= nth*nth); | ||
| 972 | |||
| 973 | const int64_t net0 = (ne00 + nth - 1) / nth; | ||
| 974 | const int64_t net1 = ne01; | ||
| 975 | const int64_t net2 = ne02; | ||
| 976 | const int64_t net3 = ne03; | ||
| 977 | |||
| 978 | const uint64_t nbt0 = sizeof(float); | ||
| 979 | const uint64_t nbt1 = net0*nbt0; | ||
| 980 | const uint64_t nbt2 = net1*nbt1; | ||
| 981 | const uint64_t nbt3 = net2*nbt2; | ||
| 982 | |||
| 983 | const size_t smem = GGML_PAD(32*sizeof(float), 16); | ||
| 984 | |||
| 985 | ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); | ||
| 986 | ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); | ||
| 987 | |||
| 988 | ggml_metal_buffer_id bid_tmp = bid_dst; | ||
| 989 | bid_tmp.offs += ggml_nbytes(op); | ||
| 990 | |||
| 991 | { | ||
| 992 | ggml_metal_kargs_cumsum_blk args = { | ||
| 993 | /*.ne00 =*/ ne00, | ||
| 994 | /*.ne01 =*/ ne01, | ||
| 995 | /*.ne02 =*/ ne02, | ||
| 996 | /*.ne03 =*/ ne03, | ||
| 997 | /*.nb00 =*/ nb00, | ||
| 998 | /*.nb01 =*/ nb01, | ||
| 999 | /*.nb02 =*/ nb02, | ||
| 1000 | /*.nb03 =*/ nb03, | ||
| 1001 | /*.net0 =*/ net0, | ||
| 1002 | /*.net1 =*/ net1, | ||
| 1003 | /*.net2 =*/ net2, | ||
| 1004 | /*.net3 =*/ net3, | ||
| 1005 | /*.nbt0 =*/ nbt0, | ||
| 1006 | /*.nbt1 =*/ nbt1, | ||
| 1007 | /*.nbt2 =*/ nbt2, | ||
| 1008 | /*.nbt3 =*/ nbt3, | ||
| 1009 | /*.outb =*/ ne00 > nth, | ||
| 1010 | }; | ||
| 1011 | |||
| 1012 | ggml_metal_encoder_set_pipeline(enc, pipeline_blk); | ||
| 1013 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 1014 | ggml_metal_encoder_set_buffer (enc, bid_src0, 1); | ||
| 1015 | ggml_metal_encoder_set_buffer (enc, bid_tmp, 2); | ||
| 1016 | ggml_metal_encoder_set_buffer (enc, bid_dst, 3); | ||
| 1017 | |||
| 1018 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 1019 | |||
| 1020 | ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1); | ||
| 1021 | } | ||
| 1022 | |||
| 1023 | if (ne00 > nth) { | ||
| 1024 | ggml_metal_op_concurrency_reset(ctx); | ||
| 1025 | |||
| 1026 | { | ||
| 1027 | ggml_metal_kargs_cumsum_blk args = { | ||
| 1028 | /*.ne00 =*/ net0, | ||
| 1029 | /*.ne01 =*/ net1, | ||
| 1030 | /*.ne02 =*/ net2, | ||
| 1031 | /*.ne03 =*/ net3, | ||
| 1032 | /*.nb00 =*/ nbt0, | ||
| 1033 | /*.nb01 =*/ nbt1, | ||
| 1034 | /*.nb02 =*/ nbt2, | ||
| 1035 | /*.nb03 =*/ nbt3, | ||
| 1036 | /*.net0 =*/ net0, | ||
| 1037 | /*.net1 =*/ net1, | ||
| 1038 | /*.net2 =*/ net2, | ||
| 1039 | /*.net3 =*/ net3, | ||
| 1040 | /*.nbt0 =*/ nbt0, | ||
| 1041 | /*.nbt1 =*/ nbt1, | ||
| 1042 | /*.nbt2 =*/ nbt2, | ||
| 1043 | /*.nbt3 =*/ nbt3, | ||
| 1044 | /*.outb =*/ false, | ||
| 1045 | }; | ||
| 1046 | |||
| 1047 | ggml_metal_encoder_set_pipeline(enc, pipeline_blk); | ||
| 1048 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 1049 | ggml_metal_encoder_set_buffer (enc, bid_tmp, 1); | ||
| 1050 | ggml_metal_encoder_set_buffer (enc, bid_tmp, 2); | ||
| 1051 | ggml_metal_encoder_set_buffer (enc, bid_tmp, 3); | ||
| 1052 | |||
| 1053 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 1054 | |||
| 1055 | ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1); | ||
| 1056 | } | ||
| 1057 | |||
| 1058 | ggml_metal_op_concurrency_reset(ctx); | ||
| 1059 | |||
| 1060 | { | ||
| 1061 | auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op); | ||
| 1062 | |||
| 1063 | ggml_metal_kargs_cumsum_add args = { | ||
| 1064 | /*.ne00 =*/ ne00, | ||
| 1065 | /*.ne01 =*/ ne01, | ||
| 1066 | /*.ne02 =*/ ne02, | ||
| 1067 | /*.ne03 =*/ ne03, | ||
| 1068 | /*.nb00 =*/ nb00, | ||
| 1069 | /*.nb01 =*/ nb01, | ||
| 1070 | /*.nb02 =*/ nb02, | ||
| 1071 | /*.nb03 =*/ nb03, | ||
| 1072 | /*.net0 =*/ net0, | ||
| 1073 | /*.net1 =*/ net1, | ||
| 1074 | /*.net2 =*/ net2, | ||
| 1075 | /*.net3 =*/ net3, | ||
| 1076 | /*.nbt0 =*/ nbt0, | ||
| 1077 | /*.nbt1 =*/ nbt1, | ||
| 1078 | /*.nbt2 =*/ nbt2, | ||
| 1079 | /*.nbt3 =*/ nbt3, | ||
| 1080 | }; | ||
| 1081 | |||
| 1082 | ggml_metal_encoder_set_pipeline(enc, pipeline_add); | ||
| 1083 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 1084 | ggml_metal_encoder_set_buffer (enc, bid_tmp, 1); | ||
| 1085 | ggml_metal_encoder_set_buffer (enc, bid_dst, 2); | ||
| 1086 | |||
| 1087 | ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1); | ||
| 1088 | } | ||
| 1089 | } | ||
| 1090 | |||
| 1091 | return 1; | ||
| 1092 | } | ||
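| | |||
| | // A CPU reference sketch of the three-phase scan dispatched above for long rows (ne00 > nth): | ||
| | // phase 1 scans each block of nth elements and records the block total in a temporary buffer, | ||
| | // phase 2 scans the block totals in place, and phase 3 adds the total of all preceding blocks | ||
| | // back onto each block. The row below is hypothetical. | ||
| | #include <algorithm> | ||
| | #include <cstdio> | ||
| | #include <vector> | ||
| | |||
| | int main() { | ||
| | const int ne00 = 10, nth = 4; // row length, block size | ||
| | std::vector<float> x(ne00, 1.0f); | ||
| | std::vector<float> dst(ne00), tmp((ne00 + nth - 1)/nth); | ||
| | |||
| | // phase 1: per-block inclusive scan + block totals | ||
| | for (int b = 0; b < (int) tmp.size(); ++b) { | ||
| | float acc = 0.0f; | ||
| | for (int i = b*nth; i < std::min((b + 1)*nth, ne00); ++i) { | ||
| | acc += x[i]; | ||
| | dst[i] = acc; | ||
| | } | ||
| | tmp[b] = acc; | ||
| | } | ||
| | // phase 2: scan of the block totals | ||
| | for (int b = 1; b < (int) tmp.size(); ++b) tmp[b] += tmp[b - 1]; | ||
| | // phase 3: add the carry from the preceding blocks | ||
| | for (int b = 1; b < (int) tmp.size(); ++b) { | ||
| | for (int i = b*nth; i < std::min((b + 1)*nth, ne00); ++i) { | ||
| | dst[i] += tmp[b - 1]; | ||
| | } | ||
| | } | ||
| | for (float v : dst) printf("%g ", v); // 1 2 3 ... 10 | ||
| | return 0; | ||
| | } | ||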
| 1093 | |||
| 1094 | int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { | ||
| 1095 | ggml_tensor * op = ctx->node(idx); | ||
| 1096 | |||
| 1097 | ggml_metal_library_t lib = ctx->lib; | ||
| 1098 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 1099 | |||
| 1100 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 1101 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 1102 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 1103 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 1104 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 1105 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 1106 | |||
| 1107 | auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); | ||
| 1108 | |||
| 1109 | ggml_metal_kargs_get_rows args = { | ||
| 1110 | /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00, | ||
| 1111 | /*.ne00 =*/ ne00, | ||
| 1112 | /*.nb01 =*/ nb01, | ||
| 1113 | /*.nb02 =*/ nb02, | ||
| 1114 | /*.nb03 =*/ nb03, | ||
| 1115 | /*.ne10 =*/ ne10, | ||
| 1116 | /*.nb10 =*/ nb10, | ||
| 1117 | /*.nb11 =*/ nb11, | ||
| 1118 | /*.nb12 =*/ nb12, | ||
| 1119 | /*.nb1 =*/ nb1, | ||
| 1120 | /*.nb2 =*/ nb2, | ||
| 1121 | /*.nb3 =*/ nb3, | ||
| 1122 | }; | ||
| 1123 | |||
| 1124 | const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); | ||
| 1125 | |||
| 1126 | const int nw0 = (args.ne00t + nth - 1)/nth; | ||
| 1127 | |||
| 1128 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 1129 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 1130 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 1131 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 1132 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); | ||
| 1133 | |||
| 1134 | ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1); | ||
| 1135 | |||
| 1136 | return 1; | ||
| 1137 | } | ||
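| | |||
| | // A small sketch of the thread mapping above: for quantized src0 each thread covers a 16-element | ||
| | // sub-block of the row (ne00t = ne00/16), otherwise one element per thread, and wide rows are | ||
| | // split into nw0 chunks of nth threads. Shape and threadgroup limit below are hypothetical. | ||
| | #include <algorithm> | ||
| | #include <cstdio> | ||
| | |||
| | int main() { | ||
| | const int ne00 = 4096; // src0 row width | ||
| | const bool quantized = true; | ||
| | const int max_tptg = 1024; // hypothetical pipeline limit | ||
| | |||
| | const int ne00t = quantized ? ne00/16 : ne00; // work items per row -> 256 | ||
| | const int nth = std::min(ne00t, max_tptg); // threads per threadgroup -> 256 | ||
| | const int nw0 = (ne00t + nth - 1)/nth; // chunks per row -> 1 | ||
| | |||
| | printf("ne00t = %d, nth = %d, nw0 = %d\n", ne00t, nth, nw0); | ||
| | return 0; | ||
| | } | ||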
| 1138 | |||
| 1139 | int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) { | ||
| 1140 | ggml_tensor * op = ctx->node(idx); | ||
| 1141 | |||
| 1142 | ggml_metal_library_t lib = ctx->lib; | ||
| 1143 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 1144 | |||
| 1145 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 1146 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 1147 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 1148 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 1149 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 1150 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 1151 | |||
| 1152 | auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type); | ||
| 1153 | |||
| 1154 | const int32_t nk0 = ne0/ggml_blck_size(op->type); | ||
| 1155 | |||
| 1156 | int nth = 32; // SIMD width | ||
| 1157 | |||
| 1158 | while (nth < nk0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { | ||
| 1159 | nth *= 2; | ||
| 1160 | } | ||
| 1161 | |||
| 1162 | int nrptg = 1; | ||
| 1163 | if (nth > nk0) { | ||
| 1164 | nrptg = (nth + nk0 - 1)/nk0; | ||
| 1165 | nth = nk0; | ||
| 1166 | |||
| 1167 | if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { | ||
| 1168 | nrptg--; | ||
| 1169 | } | ||
| 1170 | } | ||
| 1171 | |||
| 1172 | nth = std::min(nth, nk0); | ||
| 1173 | |||
| 1174 | ggml_metal_kargs_set_rows args = { | ||
| 1175 | /*.nk0 =*/ nk0, | ||
| 1176 | /*.ne01 =*/ ne01, | ||
| 1177 | /*.nb01 =*/ nb01, | ||
| 1178 | /*.nb02 =*/ nb02, | ||
| 1179 | /*.nb03 =*/ nb03, | ||
| 1180 | /*.ne11 =*/ ne11, | ||
| 1181 | /*.ne12 =*/ ne12, | ||
| 1182 | /*.nb10 =*/ nb10, | ||
| 1183 | /*.nb11 =*/ nb11, | ||
| 1184 | /*.nb12 =*/ nb12, | ||
| 1185 | /*.nb1 =*/ nb1, | ||
| 1186 | /*.nb2 =*/ nb2, | ||
| 1187 | /*.nb3 =*/ nb3, | ||
| 1188 | }; | ||
| 1189 | |||
| 1190 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 1191 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 1192 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 1193 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 1194 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); | ||
| 1195 | |||
| 1196 | ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1); | ||
| 1197 | |||
| 1198 | return 1; | ||
| 1199 | } | ||
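| | |||
| | // A standalone sketch of the row-batching logic above: nth grows in powers of two from the SIMD | ||
| | // width, and when it overshoots the row width nk0 the surplus threads become extra rows per | ||
| | // threadgroup (nrptg). The row width and threadgroup limit below are hypothetical. | ||
| | #include <algorithm> | ||
| | #include <cstdio> | ||
| | |||
| | int main() { | ||
| | const int nk0 = 48; // hypothetical dst block-row width | ||
| | const int max_tptg = 1024; // hypothetical pipeline limit | ||
| | |||
| | int nth = 32; // SIMD width | ||
| | while (nth < nk0 && nth < max_tptg) { | ||
| | nth *= 2; | ||
| | } | ||
| | |||
| | int nrptg = 1; | ||
| | if (nth > nk0) { | ||
| | nrptg = (nth + nk0 - 1)/nk0; | ||
| | nth = nk0; | ||
| | if (nrptg*nth > max_tptg) { | ||
| | nrptg--; | ||
| | } | ||
| | } | ||
| | nth = std::min(nth, nk0); | ||
| | |||
| | printf("nth = %d, nrptg = %d\n", nth, nrptg); // nth = 48, nrptg = 2 | ||
| | return 0; | ||
| | } | ||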
| 1200 | |||
| 1201 | int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) { | ||
| 1202 | ggml_tensor * op = ctx->node(idx); | ||
| 1203 | |||
| 1204 | ggml_metal_library_t lib = ctx->lib; | ||
| 1205 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 1206 | |||
| 1207 | GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne); | ||
| 1208 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 1209 | GGML_TENSOR_LOCALS(int32_t, ne, op, ne); | ||
| 1210 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 1211 | |||
| 1212 | ggml_metal_kargs_diag args = { | ||
| 1213 | /*.ne00 =*/ne00, | ||
| 1214 | /*.ne01 =*/ne01, | ||
| 1215 | /*.ne02 =*/ne02, | ||
| 1216 | /*.ne03 =*/ne03, | ||
| 1217 | /*.nb00 =*/nb00, | ||
| 1218 | /*.nb01 =*/nb01, | ||
| 1219 | /*.nb02 =*/nb02, | ||
| 1220 | /*.nb03 =*/nb03, | ||
| 1221 | /*.ne0 =*/ne0, | ||
| 1222 | /*.ne1 =*/ne1, | ||
| 1223 | /*.ne2 =*/ne2, | ||
| 1224 | /*.ne3 =*/ne3, | ||
| 1225 | /*.nb0 =*/nb0, | ||
| 1226 | /*.nb1 =*/nb1, | ||
| 1227 | /*.nb2 =*/nb2, | ||
| 1228 | /*.nb3 =*/nb3, | ||
| 1229 | }; | ||
| 1230 | |||
| 1231 | auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op); | ||
| 1232 | |||
| 1233 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 1234 | ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); | ||
| 1235 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 1236 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2); | ||
| 1237 | |||
| 1238 | ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1); | ||
| 1239 | |||
| 1240 | return 1; | ||
| 1241 | } | ||
| 1242 | |||
| 1243 | int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { | ||
| 1244 | ggml_tensor * op = ctx->node(idx); | ||
| 1245 | |||
| 1246 | ggml_metal_library_t lib = ctx->lib; | ||
| 1247 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 1248 | |||
| 1249 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 1250 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 1251 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 1252 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 1253 | GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); | ||
| 1254 | GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); | ||
| 1255 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 1256 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 1257 | |||
| 1258 | float scale; | ||
| 1259 | float max_bias; | ||
| 1260 | |||
| 1261 | memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale)); | ||
| 1262 | memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias)); | ||
| 1263 | |||
| 1264 | const uint32_t n_head = op->src[0]->ne[2]; | ||
| 1265 | const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); | ||
| 1266 | |||
| 1267 | const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); | ||
| 1268 | const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); | ||
| 1269 | |||
| 1270 | // softmax | ||
| 1271 | |||
| 1272 | ggml_metal_kargs_soft_max args = { | ||
| 1273 | /*.ne00 =*/ ne00, | ||
| 1274 | /*.ne01 =*/ ne01, | ||
| 1275 | /*.ne02 =*/ ne02, | ||
| 1276 | /*.nb01 =*/ nb01, | ||
| 1277 | /*.nb02 =*/ nb02, | ||
| 1278 | /*.nb03 =*/ nb03, | ||
| 1279 | /*.ne11 =*/ ne11, | ||
| 1280 | /*.ne12 =*/ ne12, | ||
| 1281 | /*.ne13 =*/ ne13, | ||
| 1282 | /*.nb11 =*/ nb11, | ||
| 1283 | /*.nb12 =*/ nb12, | ||
| 1284 | /*.nb13 =*/ nb13, | ||
| 1285 | /*.nb1 =*/ nb1, | ||
| 1286 | /*.nb2 =*/ nb2, | ||
| 1287 | /*.nb3 =*/ nb3, | ||
| 1288 | /*.scale =*/ scale, | ||
| 1289 | /*.max_bias =*/ max_bias, | ||
| 1290 | /*.m0 =*/ m0, | ||
| 1291 | /*.m1 =*/ m1, | ||
| 1292 | /*.n_head_log2 =*/ n_head_log2, | ||
| 1293 | }; | ||
| 1294 | |||
| 1295 | auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op); | ||
| 1296 | |||
| 1297 | int nth = 32; // SIMD width | ||
| 1298 | |||
| 1299 | if (ne00%4 == 0) { | ||
| 1300 | while (nth < ne00/4 && nth*ne01*ne02*ne03 < 256) { | ||
| 1301 | nth *= 2; | ||
| 1302 | } | ||
| 1303 | } else { | ||
| 1304 | while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { | ||
| 1305 | nth *= 2; | ||
| 1306 | } | ||
| 1307 | } | ||
| 1308 | |||
| 1309 | const size_t smem = pipeline.smem; | ||
| 1310 | |||
| 1311 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 1312 | ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); | ||
| 1313 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 1314 | if (op->src[1]) { | ||
| 1315 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 1316 | } else { | ||
| 1317 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 2); | ||
| 1318 | } | ||
| 1319 | if (op->src[2]) { | ||
| 1320 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[2]), 3); | ||
| 1321 | } else { | ||
| 1322 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3); | ||
| 1323 | } | ||
| 1324 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 4); | ||
| 1325 | |||
| 1326 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 1327 | |||
| 1328 | ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); | ||
| 1329 | |||
| 1330 | return 1; | ||
| 1331 | } | ||
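| | |||
| | // A host-side sketch of the ALiBi slope parameters computed above: m0 and m1 are two geometric | ||
| | // bases derived from max_bias and n_head_log2. The per-head slope convention shown in the loop | ||
| | // (m0^(h+1) for the first n_head_log2 heads, m1^(2*(h - n_head_log2) + 1) afterwards) follows the | ||
| | // usual ALiBi recipe and is stated here as an assumption, not copied from the Metal kernel. | ||
| | #include <cmath> | ||
| | #include <cstdint> | ||
| | #include <cstdio> | ||
| | |||
| | int main() { | ||
| | const float max_bias = 8.0f; // hypothetical | ||
| | const uint32_t n_head = 12; // hypothetical | ||
| | const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); // 8 | ||
| | |||
| | const float m0 = powf(2.0f, -(max_bias )/n_head_log2); | ||
| | const float m1 = powf(2.0f, -(max_bias/2.0f)/n_head_log2); | ||
| | |||
| | for (uint32_t h = 0; h < n_head; ++h) { | ||
| | const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1); | ||
| | printf("head %2u: slope = %g\n", h, slope); | ||
| | } | ||
| | return 0; | ||
| | } | ||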
| 1332 | |||
| 1333 | int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { | ||
| 1334 | ggml_tensor * op = ctx->node(idx); | ||
| 1335 | |||
| 1336 | ggml_metal_library_t lib = ctx->lib; | ||
| 1337 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 1338 | |||
| 1339 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 1340 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 1341 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 1342 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 1343 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 1344 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 1345 | |||
| 1346 | ggml_metal_kargs_ssm_conv args = { | ||
| 1347 | /*.ne00 =*/ ne00, | ||
| 1348 | /*.ne01 =*/ ne01, | ||
| 1349 | /*.ne02 =*/ ne02, | ||
| 1350 | /*.nb00 =*/ nb00, | ||
| 1351 | /*.nb01 =*/ nb01, | ||
| 1352 | /*.nb02 =*/ nb02, | ||
| 1353 | /*.ne10 =*/ ne10, | ||
| 1354 | /*.ne11 =*/ ne11, | ||
| 1355 | /*.nb10 =*/ nb10, | ||
| 1356 | /*.nb11 =*/ nb11, | ||
| 1357 | /*.ne0 =*/ ne0, | ||
| 1358 | /*.ne1 =*/ ne1, | ||
| 1359 | /*.ne2 =*/ ne2, | ||
| 1360 | /*.nb0 =*/ nb0, | ||
| 1361 | /*.nb1 =*/ nb1, | ||
| 1362 | /*.nb2 =*/ nb2, | ||
| 1363 | }; | ||
| 1364 | |||
| 1365 | // Use batched kernel for prefill (ne1 > 1) to reduce threadgroup dispatch overhead | ||
| 1366 | const bool use_batched = (ne1 > 1); | ||
| 1367 | |||
| 1368 | if (use_batched) { | ||
| 1369 | // Pick a power-of-2 batch size from ne1 (roughly the smallest power of two >= ne1, capped at 256) | ||
| 1370 | int BATCH_SIZE; | ||
| 1371 | if (ne1 > 128) BATCH_SIZE = 256; | ||
| 1372 | else if (ne1 > 64 ) BATCH_SIZE = 128; | ||
| 1373 | else if (ne1 > 32 ) BATCH_SIZE = 64; | ||
| 1374 | else if (ne1 > 16 ) BATCH_SIZE = 32; | ||
| 1375 | else if (ne1 > 8 ) BATCH_SIZE = 16; | ||
| 1376 | else if (ne1 > 4 ) BATCH_SIZE = 8; | ||
| 1377 | else BATCH_SIZE = 2; | ||
| 1378 | |||
| 1379 | auto pipeline = ggml_metal_library_get_pipeline_ssm_conv_batched(lib, op, BATCH_SIZE); | ||
| 1380 | |||
| 1381 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 1382 | ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); | ||
| 1383 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 1384 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 1385 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); | ||
| 1386 | |||
| 1387 | // Dispatch: ne01 rows, ceil(ne1/BATCH_SIZE) token batches, ne02 sequences | ||
| 1388 | // Each threadgroup has BATCH_SIZE threads, each handling one token | ||
| 1389 | const int n_token_batches = (ne1 + BATCH_SIZE - 1) / BATCH_SIZE; | ||
| 1390 | ggml_metal_encoder_dispatch_threadgroups(enc, ne01, n_token_batches, ne02, BATCH_SIZE, 1, 1); | ||
| 1391 | } else { | ||
| 1392 | auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op); | ||
| 1393 | |||
| 1394 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 1395 | ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); | ||
| 1396 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 1397 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 1398 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); | ||
| 1399 | |||
| 1400 | ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); | ||
| 1401 | } | ||
| 1402 | |||
| 1403 | return 1; | ||
| 1404 | } | ||
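| | |||
| | // A quick standalone check of the batch-size ladder above: BATCH_SIZE is a power of two derived | ||
| | // from the token count ne1 and capped at 256, and the grid packs ceil(ne1/BATCH_SIZE) token | ||
| | // batches along the second dispatch dimension. The token count below is hypothetical. | ||
| | #include <cstdio> | ||
| | |||
| | int main() { | ||
| | const int ne1 = 100; // hypothetical number of tokens (prefill) | ||
| | |||
| | int BATCH_SIZE; | ||
| | if (ne1 > 128) BATCH_SIZE = 256; | ||
| | else if (ne1 > 64 ) BATCH_SIZE = 128; | ||
| | else if (ne1 > 32 ) BATCH_SIZE = 64; | ||
| | else if (ne1 > 16 ) BATCH_SIZE = 32; | ||
| | else if (ne1 > 8 ) BATCH_SIZE = 16; | ||
| | else if (ne1 > 4 ) BATCH_SIZE = 8; | ||
| | else BATCH_SIZE = 2; | ||
| | |||
| | const int n_token_batches = (ne1 + BATCH_SIZE - 1)/BATCH_SIZE; | ||
| | |||
| | printf("BATCH_SIZE = %d, token batches = %d\n", BATCH_SIZE, n_token_batches); // 128, 1 | ||
| | return 0; | ||
| | } | ||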
| 1405 | |||
| 1406 | int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { | ||
| 1407 | ggml_tensor * op = ctx->node(idx); | ||
| 1408 | |||
| 1409 | ggml_metal_library_t lib = ctx->lib; | ||
| 1410 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 1411 | |||
| 1412 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 1413 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 1414 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 1415 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 1416 | GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); | ||
| 1417 | GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); | ||
| 1418 | GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); | ||
| 1419 | GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); | ||
| 1420 | GGML_TENSOR_LOCALS( int32_t, ne4, op->src[4], ne); | ||
| 1421 | GGML_TENSOR_LOCALS(uint64_t, nb4, op->src[4], nb); | ||
| 1422 | GGML_TENSOR_LOCALS( int32_t, ne5, op->src[5], ne); | ||
| 1423 | GGML_TENSOR_LOCALS(uint64_t, nb5, op->src[5], nb); | ||
| 1424 | GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne); | ||
| 1425 | GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb); | ||
| 1426 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 1427 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 1428 | |||
| 1429 | const ggml_tensor * src3 = op->src[3]; | ||
| 1430 | const ggml_tensor * src4 = op->src[4]; | ||
| 1431 | const ggml_tensor * src5 = op->src[5]; | ||
| 1432 | const ggml_tensor * src6 = op->src[6]; | ||
| 1433 | |||
| 1434 | GGML_ASSERT(src3); | ||
| 1435 | GGML_ASSERT(src4); | ||
| 1436 | GGML_ASSERT(src5); | ||
| 1437 | GGML_ASSERT(src6); | ||
| 1438 | |||
| 1439 | const int64_t d_state = ne00; | ||
| 1440 | const int64_t d_inner = ne01; | ||
| 1441 | const int64_t n_head = ne02; | ||
| 1442 | const int64_t n_group = ne41; | ||
| 1443 | const int64_t n_seq_tokens = ne12; | ||
| 1444 | const int64_t n_seqs = ne13; | ||
| 1445 | |||
| 1446 | ggml_metal_kargs_ssm_scan args = { | ||
| 1447 | /*.d_state =*/ d_state, | ||
| 1448 | /*.d_inner =*/ d_inner, | ||
| 1449 | /*.n_head =*/ n_head, | ||
| 1450 | /*.n_group =*/ n_group, | ||
| 1451 | /*.n_seq_tokens =*/ n_seq_tokens, | ||
| 1452 | /*.n_seqs =*/ n_seqs, | ||
| 1453 | /*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float), | ||
| 1454 | /*.nb00 =*/ nb00, | ||
| 1455 | /*.nb01 =*/ nb01, | ||
| 1456 | /*.nb02 =*/ nb02, | ||
| 1457 | /*.nb03 =*/ nb03, | ||
| 1458 | /*.nb10 =*/ nb10, | ||
| 1459 | /*.nb11 =*/ nb11, | ||
| 1460 | /*.nb12 =*/ nb12, | ||
| 1461 | /*.ns12 =*/ nb12/nb10, | ||
| 1462 | /*.nb13 =*/ nb13, | ||
| 1463 | /*.nb20 =*/ nb20, | ||
| 1464 | /*.nb21 =*/ nb21, | ||
| 1465 | /*.ns21 =*/ nb21/nb20, | ||
| 1466 | /*.nb22 =*/ nb22, | ||
| 1467 | /*.ne30 =*/ ne30, | ||
| 1468 | /*.nb31 =*/ nb31, | ||
| 1469 | /*.nb41 =*/ nb41, | ||
| 1470 | /*.nb42 =*/ nb42, | ||
| 1471 | /*.ns42 =*/ nb42/nb40, | ||
| 1472 | /*.nb43 =*/ nb43, | ||
| 1473 | /*.nb51 =*/ nb51, | ||
| 1474 | /*.nb52 =*/ nb52, | ||
| 1475 | /*.ns52 =*/ nb52/nb50, | ||
| 1476 | /*.nb53 =*/ nb53, | ||
| 1477 | /*.nb0 =*/ nb0, | ||
| 1478 | }; | ||
| 1479 | |||
| 1480 | auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); | ||
| 1481 | |||
| 1482 | GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); | ||
| 1483 | |||
| 1484 | const size_t smem = pipeline.smem; | ||
| 1485 | |||
| 1486 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 1487 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 1488 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 1489 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 1490 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); | ||
| 1491 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), 4); | ||
| 1492 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), 5); | ||
| 1493 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), 6); | ||
| 1494 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), 7); | ||
| 1495 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 8); | ||
| 1496 | |||
| 1497 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 1498 | |||
| 1499 | ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); | ||
| 1500 | |||
| 1501 | return 1; | ||
| 1502 | } | ||
| 1503 | |||
| 1504 | int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { | ||
| 1505 | ggml_tensor * op = ctx->node(idx); | ||
| 1506 | |||
| 1507 | ggml_metal_library_t lib = ctx->lib; | ||
| 1508 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 1509 | |||
| 1510 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 1511 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 1512 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 1513 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 1514 | |||
| 1515 | const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1]; | ||
| 1516 | const int64_t T = op->src[0]->ne[2]; | ||
| 1517 | const int64_t C = op->ne[0]; | ||
| 1518 | const int64_t H = op->src[0]->ne[1]; | ||
| 1519 | |||
| 1520 | auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op); | ||
| 1521 | |||
| 1522 | int ida = 0; | ||
| 1523 | |||
| 1524 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 1525 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); | ||
| 1526 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); | ||
| 1527 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); | ||
| 1528 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); | ||
| 1529 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); | ||
| 1530 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); | ||
| 1531 | if (op->op == GGML_OP_RWKV_WKV7) { | ||
| 1532 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++); | ||
| 1533 | } | ||
| 1534 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); | ||
| 1535 | ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++); | ||
| 1536 | ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++); | ||
| 1537 | ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++); | ||
| 1538 | ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++); | ||
| 1539 | |||
| 1540 | ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1); | ||
| 1541 | |||
| 1542 | return 1; | ||
| 1543 | } | ||
| 1544 | |||
| 1545 | int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { | ||
| 1546 | ggml_tensor * op = ctx->node(idx); | ||
| 1547 | |||
| 1548 | ggml_metal_library_t lib = ctx->lib; | ||
| 1549 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 1550 | |||
| 1551 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 1552 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 1553 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 1554 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 1555 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 1556 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 1557 | |||
| 1558 | ggml_metal_kargs_solve_tri args = { | ||
| 1559 | /*.ne00 =*/ ne00, | ||
| 1560 | /*.ne01 =*/ ne01, | ||
| 1561 | /*.ne02 =*/ ne02, | ||
| 1562 | /*.ne03 =*/ ne03, | ||
| 1563 | /*.nb00 =*/ nb00, | ||
| 1564 | /*.nb01 =*/ nb01, | ||
| 1565 | /*.nb02 =*/ nb02, | ||
| 1566 | /*.nb03 =*/ nb03, | ||
| 1567 | /*.ne10 =*/ ne10, | ||
| 1568 | /*.ne11 =*/ ne11, | ||
| 1569 | /*.ne12 =*/ ne12, | ||
| 1570 | /*.ne13 =*/ ne13, | ||
| 1571 | /*.nb10 =*/ nb10, | ||
| 1572 | /*.nb11 =*/ nb11, | ||
| 1573 | /*.nb12 =*/ nb12, | ||
| 1574 | /*.nb13 =*/ nb13, | ||
| 1575 | /*.ne0 =*/ ne0, | ||
| 1576 | /*.ne1 =*/ ne1, | ||
| 1577 | /*.ne2 =*/ ne2, | ||
| 1578 | /*.ne3 =*/ ne3, | ||
| 1579 | /*.nb0 =*/ nb0, | ||
| 1580 | /*.nb1 =*/ nb1, | ||
| 1581 | /*.nb2 =*/ nb2, | ||
| 1582 | /*.nb3 =*/ nb3, | ||
| 1583 | }; | ||
| 1584 | |||
| 1585 | auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op); | ||
| 1586 | |||
| 1587 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 1588 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 1589 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 1590 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 1591 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); | ||
| 1592 | |||
| 1593 | const int nsg = pipeline.nsg; | ||
| 1594 | |||
| 1595 | ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0); | ||
| 1596 | |||
| 1597 | ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1); | ||
| 1598 | |||
| 1599 | return 1; | ||
| 1600 | } | ||
| 1601 | |||
| 1602 | int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { | ||
| 1603 | ggml_tensor * op = ctx->node(idx); | ||
| 1604 | |||
| 1605 | ggml_metal_library_t lib = ctx->lib; | ||
| 1606 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 1607 | |||
| 1608 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 1609 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 1610 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 1611 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 1612 | |||
| 1613 | auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); | ||
| 1614 | |||
| 1615 | GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0); | ||
| 1616 | |||
| 1617 | int64_t nk0 = ne00; | ||
| 1618 | if (ggml_is_quantized(op->src[0]->type)) { | ||
| 1619 | nk0 = ne00/16; | ||
| 1620 | } else if (ggml_is_quantized(op->type)) { | ||
| 1621 | nk0 = ne00/ggml_blck_size(op->type); | ||
| 1622 | } | ||
| 1623 | |||
| 1624 | int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); | ||
| 1625 | |||
| 1626 | // when rows are small, we can batch them together in a single threadgroup | ||
| 1627 | int nrptg = 1; | ||
| 1628 | |||
| 1629 | // TODO: relax this constraint in the future | ||
| 1630 | if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) { | ||
| 1631 | if (nth > nk0) { | ||
| 1632 | nrptg = (nth + nk0 - 1)/nk0; | ||
| 1633 | nth = nk0; | ||
| 1634 | |||
| 1635 | if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { | ||
| 1636 | nrptg--; | ||
| 1637 | } | ||
| 1638 | } | ||
| 1639 | } | ||
| 1640 | |||
| 1641 | nth = std::min<int>(nth, nk0); | ||
| 1642 | |||
| 1643 | ggml_metal_kargs_cpy args = { | ||
| 1644 | /*.nk0 =*/ nk0, | ||
| 1645 | /*.ne00 =*/ ne00, | ||
| 1646 | /*.ne01 =*/ ne01, | ||
| 1647 | /*.ne02 =*/ ne02, | ||
| 1648 | /*.ne03 =*/ ne03, | ||
| 1649 | /*.nb00 =*/ nb00, | ||
| 1650 | /*.nb01 =*/ nb01, | ||
| 1651 | /*.nb02 =*/ nb02, | ||
| 1652 | /*.nb03 =*/ nb03, | ||
| 1653 | /*.ne0 =*/ ne0, | ||
| 1654 | /*.ne1 =*/ ne1, | ||
| 1655 | /*.ne2 =*/ ne2, | ||
| 1656 | /*.ne3 =*/ ne3, | ||
| 1657 | /*.nb0 =*/ nb0, | ||
| 1658 | /*.nb1 =*/ nb1, | ||
| 1659 | /*.nb2 =*/ nb2, | ||
| 1660 | /*.nb3 =*/ nb3, | ||
| 1661 | }; | ||
| 1662 | |||
| 1663 | const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1; | ||
| 1664 | |||
| 1665 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 1666 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 1667 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 1668 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); | ||
| 1669 | |||
| 1670 | ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1); | ||
| 1671 | |||
| 1672 | return 1; | ||
| 1673 | } | ||
| 1674 | |||
| 1675 | int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) { | ||
| 1676 | ggml_tensor * op = ctx->node(idx); | ||
| 1677 | |||
| 1678 | ggml_metal_library_t lib = ctx->lib; | ||
| 1679 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 1680 | |||
| 1681 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 1682 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 1683 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 1684 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 1685 | |||
| 1686 | const int32_t * opts = op->op_params; | ||
| 1687 | ggml_op_pool op_pool = (ggml_op_pool) opts[0]; | ||
| 1688 | |||
| 1689 | const int32_t k0 = opts[1]; | ||
| 1690 | const int32_t s0 = opts[2]; | ||
| 1691 | const int32_t p0 = opts[3]; | ||
| 1692 | |||
| 1693 | const int64_t IW = op->src[0]->ne[0]; | ||
| 1694 | const int64_t OW = op->ne[0]; | ||
| 1695 | |||
| 1696 | const int64_t np = ggml_nelements(op); | ||
| 1697 | |||
| 1698 | ggml_metal_kargs_pool_1d args_pool_1d = { | ||
| 1699 | /* .k0 = */ k0, | ||
| 1700 | /* .s0 = */ s0, | ||
| 1701 | /* .p0 = */ p0, | ||
| 1702 | /* .IW = */ IW, | ||
| 1703 | /* .OW = */ OW, | ||
| 1704 | /* .np = */ np | ||
| 1705 | }; | ||
| 1706 | |||
| 1707 | auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool); | ||
| 1708 | |||
| 1709 | const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np); | ||
| 1710 | const int ntg = (np + nth - 1) / nth; | ||
| 1711 | |||
| 1712 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 1713 | ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0); | ||
| 1714 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 1715 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); | ||
| 1716 | |||
| 1717 | ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1); | ||
| 1718 | |||
| 1719 | return 1; | ||
| 1720 | } | ||
| 1721 | |||
| 1722 | |||
| 1723 | int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) { | ||
| 1724 | ggml_tensor * op = ctx->node(idx); | ||
| 1725 | |||
| 1726 | ggml_metal_library_t lib = ctx->lib; | ||
| 1727 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 1728 | |||
| 1729 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 1730 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 1731 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 1732 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 1733 | |||
| 1734 | const int32_t * opts = op->op_params; | ||
| 1735 | ggml_op_pool op_pool = (ggml_op_pool) opts[0]; | ||
| 1736 | |||
| 1737 | const int32_t k0 = opts[1]; | ||
| 1738 | const int32_t k1 = opts[2]; | ||
| 1739 | const int32_t s0 = opts[3]; | ||
| 1740 | const int32_t s1 = opts[4]; | ||
| 1741 | const int32_t p0 = opts[5]; | ||
| 1742 | const int32_t p1 = opts[6]; | ||
| 1743 | |||
| 1744 | const int64_t IH = op->src[0]->ne[1]; | ||
| 1745 | const int64_t IW = op->src[0]->ne[0]; | ||
| 1746 | |||
| 1747 | const int64_t N = op->ne[3]; | ||
| 1748 | const int64_t OC = op->ne[2]; | ||
| 1749 | const int64_t OH = op->ne[1]; | ||
| 1750 | const int64_t OW = op->ne[0]; | ||
| 1751 | |||
| 1752 | const int64_t np = N * OC * OH * OW; | ||
| 1753 | |||
| 1754 | ggml_metal_kargs_pool_2d args_pool_2d = { | ||
| 1755 | /* .k0 = */ k0, | ||
| 1756 | /* .k1 = */ k1, | ||
| 1757 | /* .s0 = */ s0, | ||
| 1758 | /* .s1 = */ s1, | ||
| 1759 | /* .p0 = */ p0, | ||
| 1760 | /* .p1 = */ p1, | ||
| 1761 | /* .IH = */ IH, | ||
| 1762 | /* .IW = */ IW, | ||
| 1763 | /* .OH = */ OH, | ||
| 1764 | /* .OW = */ OW, | ||
| 1765 | /* .np = */ np | ||
| 1766 | }; | ||
| 1767 | |||
| 1768 | auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool); | ||
| 1769 | |||
| 1770 | const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np); | ||
| 1771 | const int ntg = (np + nth - 1) / nth; | ||
| 1772 | |||
| 1773 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 1774 | ggml_metal_encoder_set_bytes (enc, &args_pool_2d, sizeof(args_pool_2d), 0); | ||
| 1775 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 1776 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); | ||
| 1777 | |||
| 1778 | ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1); | ||
| 1779 | |||
| 1780 | return 1; | ||
| 1781 | } | ||
| 1782 | |||
| 1783 | int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { | ||
| 1784 | ggml_tensor * op = ctx->node(idx); | ||
| 1785 | |||
| 1786 | ggml_metal_library_t lib = ctx->lib; | ||
| 1787 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 1788 | |||
| 1789 | const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev); | ||
| 1790 | |||
| 1791 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 1792 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 1793 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 1794 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 1795 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 1796 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 1797 | |||
| 1798 | GGML_ASSERT(ne00 == ne10); | ||
| 1799 | |||
| 1800 | GGML_ASSERT(ne12 % ne02 == 0); | ||
| 1801 | GGML_ASSERT(ne13 % ne03 == 0); | ||
| 1802 | |||
| 1803 | const int16_t r2 = ne12/ne02; | ||
| 1804 | const int16_t r3 = ne13/ne03; | ||
| 1805 | |||
| 1806 | // find the break-even point where the matrix-matrix kernel becomes more efficient compared | ||
| 1807 | // to the matrix-vector kernel | ||
| 1808 | const int ne11_mm_min = 8; | ||
| 1809 | |||
| 1810 | // first try to use small-batch mat-mv kernels | ||
| 1811 | // these should be efficient for BS [2, ~8] | ||
| 1812 | if (op->src[1]->type == GGML_TYPE_F32 && (ne00%128 == 0) && | ||
| 1813 | ( | ||
| 1814 | ( | ||
| 1815 | ( | ||
| 1816 | op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function | ||
| 1817 | op->src[0]->type == GGML_TYPE_F16 || | ||
| 1818 | op->src[0]->type == GGML_TYPE_Q4_0 || | ||
| 1819 | op->src[0]->type == GGML_TYPE_Q4_1 || | ||
| 1820 | op->src[0]->type == GGML_TYPE_Q5_0 || | ||
| 1821 | op->src[0]->type == GGML_TYPE_Q5_1 || | ||
| 1822 | op->src[0]->type == GGML_TYPE_Q8_0 || | ||
| 1823 | op->src[0]->type == GGML_TYPE_MXFP4 || | ||
| 1824 | op->src[0]->type == GGML_TYPE_IQ4_NL || | ||
| 1825 | false) && (ne11 >= 2 && ne11 <= 8) | ||
| 1826 | ) || | ||
| 1827 | ( | ||
| 1828 | ( | ||
| 1829 | op->src[0]->type == GGML_TYPE_Q4_K || | ||
| 1830 | op->src[0]->type == GGML_TYPE_Q5_K || | ||
| 1831 | op->src[0]->type == GGML_TYPE_Q6_K || | ||
| 1832 | false) && (ne11 >= 4 && ne11 <= 8) | ||
| 1833 | ) | ||
| 1834 | ) | ||
| 1835 | ) { | ||
| 1836 | // TODO: determine the optimal parameters based on grid utilization | ||
| 1837 | // I still don't know why we should not always use the maximum available threads: | ||
| 1838 | // | ||
| 1839 | // nsg = pipeline.maxTotalThreadsPerThreadgroup / 32 | ||
| 1840 | // | ||
| 1841 | // my current hypothesis is that the work grid is not evenly divisible for different nsg | ||
| 1842 | // values and there can be some tail effects when nsg is high. need to confirm this | ||
| 1843 | // | ||
| 1844 | const int nsg = 2; // num simdgroups per threadgroup | ||
| 1845 | |||
| 1846 | // num threads along row per simdgroup | ||
| 1847 | int16_t nxpsg = 0; | ||
| 1848 | if (ne00 % 256 == 0 && ne11 < 3) { | ||
| 1849 | nxpsg = 16; | ||
| 1850 | } else if (ne00 % 128 == 0) { | ||
| 1851 | nxpsg = 8; | ||
| 1852 | } else { | ||
| 1853 | nxpsg = 4; | ||
| 1854 | } | ||
| 1855 | |||
| 1856 | const int16_t nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time) | ||
| 1857 | const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup | ||
| 1858 | int16_t r1ptg = 4; // num src1 rows per threadgroup | ||
| 1859 | |||
| 1860 | // note: not sure how optimal those are across all different hardware. there might be something cleverer | ||
| 1861 | switch (ne11) { | ||
| 1862 | case 2: | ||
| 1863 | r1ptg = 2; break; | ||
| 1864 | case 3: | ||
| 1865 | case 6: | ||
| 1866 | r1ptg = 3; break; | ||
| 1867 | case 4: | ||
| 1868 | case 7: | ||
| 1869 | case 8: | ||
| 1870 | r1ptg = 4; break; | ||
| 1871 | case 5: | ||
| 1872 | r1ptg = 5; break; | ||
| 1873 | default: | ||
| 1874 | GGML_ABORT("unsupported ne11"); | ||
| 1875 | }; | ||
| 1876 | |||
| 1877 | auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); | ||
| 1878 | |||
| 1879 | ggml_metal_kargs_mul_mv_ext args = { | ||
| 1880 | /*.ne00 =*/ ne00, | ||
| 1881 | /*.ne01 =*/ ne01, | ||
| 1882 | /*.ne02 =*/ ne02, | ||
| 1883 | /*.nb00 =*/ nb00, | ||
| 1884 | /*.nb01 =*/ nb01, | ||
| 1885 | /*.nb02 =*/ nb02, | ||
| 1886 | /*.nb03 =*/ nb03, | ||
| 1887 | /*.ne10 =*/ ne10, | ||
| 1888 | /*.ne11 =*/ ne11, | ||
| 1889 | /*.ne12 =*/ ne12, | ||
| 1890 | /*.nb10 =*/ nb10, | ||
| 1891 | /*.nb11 =*/ nb11, | ||
| 1892 | /*.nb12 =*/ nb12, | ||
| 1893 | /*.nb13 =*/ nb13, | ||
| 1894 | /*.ne0 =*/ ne0, | ||
| 1895 | /*.ne1 =*/ ne1, | ||
| 1896 | /*.r2 =*/ r2, | ||
| 1897 | /*.r3 =*/ r3, | ||
| 1898 | }; | ||
| 1899 | |||
| 1900 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 1901 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 1902 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 1903 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 1904 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); | ||
| 1905 | |||
| 1906 | ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + r0ptg - 1)/r0ptg), ((ne11 + r1ptg - 1)/r1ptg), ne12*ne13, 32, nsg, 1); | ||
| 1907 | } else if ( | ||
| 1908 | !ggml_is_transposed(op->src[0]) && | ||
| 1909 | !ggml_is_transposed(op->src[1]) && | ||
| 1910 | // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs | ||
| 1911 | // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel | ||
| 1912 | props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) { | ||
| 1913 | //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); | ||
| 1914 | |||
| 1915 | // some Metal matrix data types require aligned pointers | ||
| 1916 | // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) | ||
| 1917 | //switch (op->src[0]->type) { | ||
| 1918 | // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; | ||
| 1919 | // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; | ||
| 1920 | // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; | ||
| 1921 | // default: break; | ||
| 1922 | //} | ||
| 1923 | |||
| 1924 | auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op); | ||
| 1925 | |||
| 1926 | ggml_metal_kargs_mul_mm args = { | ||
| 1927 | /*.ne00 =*/ ne00, | ||
| 1928 | /*.ne02 =*/ ne02, | ||
| 1929 | /*.nb01 =*/ nb01, | ||
| 1930 | /*.nb02 =*/ nb02, | ||
| 1931 | /*.nb03 =*/ nb03, | ||
| 1932 | /*.ne12 =*/ ne12, | ||
| 1933 | /*.nb10 =*/ nb10, | ||
| 1934 | /*.nb11 =*/ nb11, | ||
| 1935 | /*.nb12 =*/ nb12, | ||
| 1936 | /*.nb13 =*/ nb13, | ||
| 1937 | /*.ne0 =*/ ne0, | ||
| 1938 | /*.ne1 =*/ ne1, | ||
| 1939 | /*.r2 =*/ r2, | ||
| 1940 | /*.r3 =*/ r3, | ||
| 1941 | }; | ||
| 1942 | |||
| 1943 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 1944 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 1945 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 1946 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 1947 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); | ||
| 1948 | |||
| 1949 | const size_t smem = pipeline.smem; | ||
| 1950 | |||
| 1951 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 1952 | ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1); | ||
| 1953 | } else { | ||
| 1954 | auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); | ||
| 1955 | |||
| 1956 | const int nr0 = pipeline.nr0; | ||
| 1957 | const int nr1 = pipeline.nr1; | ||
| 1958 | const int nsg = pipeline.nsg; | ||
| 1959 | |||
| 1960 | const size_t smem = pipeline.smem; | ||
| 1961 | |||
| 1962 | ggml_metal_kargs_mul_mv args = { | ||
| 1963 | /*.ne00 =*/ ne00, | ||
| 1964 | /*.ne01 =*/ ne01, | ||
| 1965 | /*.ne02 =*/ ne02, | ||
| 1966 | /*.nb00 =*/ nb00, | ||
| 1967 | /*.nb01 =*/ nb01, | ||
| 1968 | /*.nb02 =*/ nb02, | ||
| 1969 | /*.nb03 =*/ nb03, | ||
| 1970 | /*.ne10 =*/ ne10, | ||
| 1971 | /*.ne11 =*/ ne11, | ||
| 1972 | /*.ne12 =*/ ne12, | ||
| 1973 | /*.nb10 =*/ nb10, | ||
| 1974 | /*.nb11 =*/ nb11, | ||
| 1975 | /*.nb12 =*/ nb12, | ||
| 1976 | /*.nb13 =*/ nb13, | ||
| 1977 | /*.ne0 =*/ ne0, | ||
| 1978 | /*.ne1 =*/ ne1, | ||
| 1979 | /*.nr0 =*/ nr0, | ||
| 1980 | /*.r2 =*/ r2, | ||
| 1981 | /*.r3 =*/ r3, | ||
| 1982 | }; | ||
| 1983 | |||
| 1984 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 1985 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 1986 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 1987 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 1988 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); | ||
| 1989 | |||
| 1990 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 1991 | |||
| 1992 | if (op->src[0]->type == GGML_TYPE_F32 || | ||
| 1993 | op->src[0]->type == GGML_TYPE_F16 || | ||
| 1994 | op->src[0]->type == GGML_TYPE_BF16 || | ||
| 1995 | op->src[0]->type == GGML_TYPE_Q8_0) { | ||
| 1996 | ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1); | ||
| 1997 | } else { | ||
| 1998 | ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1); | ||
| 1999 | } | ||
| 2000 | } | ||
| 2001 | |||
| 2002 | return 1; | ||
| 2003 | } | ||
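| | |||
| | // A condensed sketch of the kernel selection above, with the per-type whitelists collapsed into a | ||
| | // single type_ok_ext flag for readability: small src1 batches (2..8 rows) go to the "ext" | ||
| | // matrix-vector kernels, larger batches go to the simdgroup matrix-matrix kernel on hardware that | ||
| | // supports it, and everything else falls back to the general matrix-vector kernel. | ||
| | #include <cstdio> | ||
| | |||
| | enum class mm_kernel { mv_ext, mm, mv }; | ||
| | |||
| | static mm_kernel pick(int ne00, int ne11, bool src1_is_f32, bool type_ok_ext, bool transposed, bool has_simdgroup_mm) { | ||
| | const int ne11_mm_min = 8; | ||
| | if (src1_is_f32 && ne00 % 128 == 0 && type_ok_ext && ne11 >= 2 && ne11 <= 8) { | ||
| | return mm_kernel::mv_ext; // small-batch matrix-vector kernels | ||
| | } | ||
| | if (!transposed && has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) { | ||
| | return mm_kernel::mm; // simdgroup matrix-matrix kernel | ||
| | } | ||
| | return mm_kernel::mv; // general matrix-vector kernel | ||
| | } | ||
| | |||
| | int main() { | ||
| | printf("%d\n", (int) pick(4096, 1, true, true, false, true)); // mv (single row) | ||
| | printf("%d\n", (int) pick(4096, 4, true, true, false, true)); // mv_ext | ||
| | printf("%d\n", (int) pick(4096, 512, true, true, false, true)); // mm | ||
| | return 0; | ||
| | } | ||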
| 2004 | |||
| 2005 | size_t ggml_metal_op_mul_mat_id_extra_tpe(const ggml_tensor * op) { | ||
| 2006 | assert(op->op == GGML_OP_MUL_MAT_ID); | ||
| 2007 | |||
| 2008 | const int64_t ne02 = op->src[0]->ne[2]; // n_expert | ||
| 2009 | |||
| 2010 | return ggml_type_size(GGML_TYPE_I32)*ne02; | ||
| 2011 | } | ||
| 2012 | |||
| 2013 | size_t ggml_metal_op_mul_mat_id_extra_ids(const ggml_tensor * op) { | ||
| 2014 | assert(op->op == GGML_OP_MUL_MAT_ID); | ||
| 2015 | |||
| 2016 | const int64_t ne02 = op->src[0]->ne[2]; // n_expert | ||
| 2017 | const int64_t ne21 = op->src[2]->ne[1]; // n_token | ||
| 2018 | |||
| 2019 | return ggml_type_size(GGML_TYPE_I32)*ne02*ne21; | ||
| 2020 | } | ||
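| | |||
| | // A worked example of the scratch sizes reserved above, with hypothetical MoE dimensions: the | ||
| | // "tpe" region holds one int32 token count per expert and the "ids" region holds an int32 slot | ||
| | // for every (expert, token) pair; both are carved out of the destination buffer right after the | ||
| | // dst tensor data (see bid_tpe / bid_ids in the function below). | ||
| | #include <cstdint> | ||
| | #include <cstdio> | ||
| | |||
| | int main() { | ||
| | const int64_t n_expert = 64; // ne02 of src0 (hypothetical) | ||
| | const int64_t n_token = 512; // ne21 of src2 (hypothetical) | ||
| | |||
| | const size_t size_tpe = sizeof(int32_t)*n_expert; // 256 bytes | ||
| | const size_t size_ids = sizeof(int32_t)*n_expert*n_token; // 131072 bytes | ||
| | |||
| | printf("tpe = %zu bytes, ids = %zu bytes\n", size_tpe, size_ids); | ||
| | return 0; | ||
| | } | ||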
| 2021 | |||
| 2022 | int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { | ||
| 2023 | ggml_tensor * op = ctx->node(idx); | ||
| 2024 | |||
| 2025 | ggml_metal_library_t lib = ctx->lib; | ||
| 2026 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 2027 | |||
| 2028 | const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev); | ||
| 2029 | |||
| 2030 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 2031 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 2032 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 2033 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 2034 | GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); | ||
| 2035 | GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); | ||
| 2036 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 2037 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 2038 | |||
| 2039 | // src2 = ids | ||
| 2040 | GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32); | ||
| 2041 | |||
| 2042 | GGML_ASSERT(!ggml_is_transposed(op->src[0])); | ||
| 2043 | GGML_ASSERT(!ggml_is_transposed(op->src[1])); | ||
| 2044 | |||
| 2045 | GGML_ASSERT(ne03 == 1); | ||
| 2046 | GGML_ASSERT(ne13 == 1); | ||
| 2047 | |||
| 2048 | ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); | ||
| 2049 | ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); | ||
| 2050 | ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]); | ||
| 2051 | ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); | ||
| 2052 | |||
| 2053 | const uint32_t r2 = 1; | ||
| 2054 | const uint32_t r3 = 1; | ||
| 2055 | |||
| 2056 | // find the break-even point where the matrix-matrix kernel becomes more efficient compared | ||
| 2057 | // to the matrix-vector kernel | ||
| 2058 | // ne20 = n_used_experts | ||
| 2059 | // ne21 = n_rows (batch size) | ||
| 2060 | const int ne21_mm_id_min = 32; | ||
| 2061 | |||
| 2062 | if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) { | ||
| 2063 | // some Metal matrix data types require aligned pointers | ||
| 2064 | // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) | ||
| 2065 | //switch (op->src[0]->type) { | ||
| 2066 | // case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; | ||
| 2067 | // case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break; | ||
| 2068 | // case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break; | ||
| 2069 | // default: break; | ||
| 2070 | //} | ||
| 2071 | |||
| 2072 | // extra buffers for intermediate id mapping | ||
| 2073 | ggml_metal_buffer_id bid_tpe = bid_dst; | ||
| 2074 | bid_tpe.offs += ggml_nbytes(op); | ||
| 2075 | |||
| 2076 | ggml_metal_buffer_id bid_ids = bid_tpe; | ||
| 2077 | bid_ids.offs += ggml_metal_op_mul_mat_id_extra_tpe(op); | ||
| 2078 | |||
| 2079 | { | ||
| 2080 | ggml_metal_kargs_mul_mm_id_map0 args = { | ||
| 2081 | ne02, | ||
| 2082 | ne10, | ||
| 2083 | ne11, // n_expert_used (bcast) | ||
| 2084 | nb11, | ||
| 2085 | nb12, | ||
| 2086 | ne21, // n_tokens | ||
| 2087 | ne20, // n_expert_used | ||
| 2088 | nb21, | ||
| 2089 | }; | ||
| 2090 | |||
| 2091 | auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20); | ||
| 2092 | |||
| 2093 | const size_t smem = pipeline.smem; | ||
| 2094 | |||
| 2095 | GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); | ||
| 2096 | |||
| 2097 | GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size); | ||
| 2098 | |||
| 2099 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 2100 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 2101 | ggml_metal_encoder_set_buffer (enc, bid_src2, 1); | ||
| 2102 | ggml_metal_encoder_set_buffer (enc, bid_tpe, 2); | ||
| 2103 | ggml_metal_encoder_set_buffer (enc, bid_ids, 3); | ||
| 2104 | |||
| 2105 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 2106 | |||
| 2107 | ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, ne02, 1, 1); | ||
| 2108 | } | ||
| 2109 | |||
| 2110 | // this barrier is always needed because the next kernel has to wait for the id maps to be computed | ||
| 2111 | ggml_metal_op_concurrency_reset(ctx); | ||
| 2112 | |||
| 2113 | { | ||
| 2114 | auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op); | ||
| 2115 | |||
| 2116 | ggml_metal_kargs_mul_mm_id args = { | ||
| 2117 | /*.ne00 =*/ ne00, | ||
| 2118 | /*.ne02 =*/ ne02, | ||
| 2119 | /*.nb01 =*/ nb01, | ||
| 2120 | /*.nb02 =*/ nb02, | ||
| 2121 | /*.nb03 =*/ nb03, | ||
| 2122 | /*.ne11 =*/ ne11, // n_expert_used (bcast) | ||
| 2123 | /*.nb10 =*/ nb10, | ||
| 2124 | /*.nb11 =*/ nb11, | ||
| 2125 | /*.nb12 =*/ nb12, | ||
| 2126 | /*.nb13 =*/ nb13, | ||
| 2127 | /*.ne20 =*/ ne20, // n_expert_used | ||
| 2128 | /*.ne21 =*/ ne21, // n_tokens | ||
| 2129 | /*.ne0 =*/ ne0, | ||
| 2130 | /*.ne1 =*/ ne1, | ||
| 2131 | /*.r2 =*/ r2, | ||
| 2132 | /*.r3 =*/ r3, | ||
| 2133 | }; | ||
| 2134 | |||
| 2135 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 2136 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 2137 | ggml_metal_encoder_set_buffer (enc, bid_src0, 1); | ||
| 2138 | ggml_metal_encoder_set_buffer (enc, bid_src1, 2); | ||
| 2139 | ggml_metal_encoder_set_buffer (enc, bid_tpe, 3); | ||
| 2140 | ggml_metal_encoder_set_buffer (enc, bid_ids, 4); | ||
| 2141 | ggml_metal_encoder_set_buffer (enc, bid_dst, 5); | ||
| 2142 | |||
| 2143 | const size_t smem = pipeline.smem; | ||
| 2144 | |||
| 2145 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 2146 | |||
| 2147 | ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1); | ||
| 2148 | } | ||
| 2149 | } else { | ||
| 2150 | auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); | ||
| 2151 | |||
| 2152 | const int nr0 = pipeline.nr0; | ||
| 2153 | const int nr1 = pipeline.nr1; | ||
| 2154 | const int nsg = pipeline.nsg; | ||
| 2155 | |||
| 2156 | const size_t smem = pipeline.smem; | ||
| 2157 | |||
| 2158 | ggml_metal_kargs_mul_mv_id args = { | ||
| 2159 | /*.nei0 =*/ ne20, | ||
| 2160 | /*.nei1 =*/ ne21, | ||
| 2161 | /*.nbi1 =*/ nb21, | ||
| 2162 | /*.ne00 =*/ ne00, | ||
| 2163 | /*.ne01 =*/ ne01, | ||
| 2164 | /*.ne02 =*/ ne02, | ||
| 2165 | /*.nb00 =*/ nb00, | ||
| 2166 | /*.nb01 =*/ nb01, | ||
| 2167 | /*.nb02 =*/ nb02, | ||
| 2168 | /*.ne10 =*/ ne10, | ||
| 2169 | /*.ne11 =*/ ne11, | ||
| 2170 | /*.ne12 =*/ ne12, | ||
| 2171 | /*.ne13 =*/ ne13, | ||
| 2172 | /*.nb10 =*/ nb10, | ||
| 2173 | /*.nb11 =*/ nb11, | ||
| 2174 | /*.nb12 =*/ nb12, | ||
| 2175 | /*.ne0 =*/ ne0, | ||
| 2176 | /*.ne1 =*/ ne1, | ||
| 2177 | /*.nb1 =*/ nb1, | ||
| 2178 | /*.nr0 =*/ nr0, | ||
| 2179 | }; | ||
| 2180 | |||
| 2181 | if (ggml_is_quantized(op->src[0]->type)) { | ||
| 2182 | GGML_ASSERT(ne00 >= nsg*nr0); | ||
| 2183 | } | ||
| 2184 | |||
| 2185 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 2186 | ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); | ||
| 2187 | ggml_metal_encoder_set_buffer(enc, bid_src0, 1); | ||
| 2188 | ggml_metal_encoder_set_buffer(enc, bid_src1, 2); | ||
| 2189 | ggml_metal_encoder_set_buffer(enc, bid_dst, 3); | ||
| 2190 | ggml_metal_encoder_set_buffer(enc, bid_src2, 4); | ||
| 2191 | |||
| 2192 | const int64_t _ne1 = 1; | ||
| 2193 | const int64_t ne123 = ne20*ne21; | ||
| 2194 | |||
| 2195 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 2196 | |||
| 2197 | if (op->src[0]->type == GGML_TYPE_F32 || | ||
| 2198 | op->src[0]->type == GGML_TYPE_F16 || | ||
| 2199 | op->src[0]->type == GGML_TYPE_BF16 || | ||
| 2200 | op->src[0]->type == GGML_TYPE_Q8_0) { | ||
| 2201 | ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1); | ||
| 2202 | } else { | ||
| 2203 | ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1); | ||
| 2204 | } | ||
| 2205 | } | ||
| 2206 | |||
| 2207 | return 1; | ||
| 2208 | } | ||
| 2209 | |||
| 2210 | int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) { | ||
| 2211 | ggml_tensor * op = ctx->node(idx); | ||
| 2212 | |||
| 2213 | ggml_metal_library_t lib = ctx->lib; | ||
| 2214 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 2215 | |||
| 2216 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 2217 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 2218 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 2219 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 2220 | GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); | ||
| 2221 | GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); | ||
| 2222 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 2223 | |||
| 2224 | GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); | ||
| 2225 | GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); | ||
| 2226 | GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32); | ||
| 2227 | GGML_ASSERT(op->type == GGML_TYPE_F32); | ||
| 2228 | |||
| 2229 | GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); | ||
| 2230 | |||
| 2231 | ggml_metal_kargs_add_id args = { | ||
| 2232 | /*.ne0 =*/ ne0, | ||
| 2233 | /*.ne1 =*/ ne1, | ||
| 2234 | /*.nb01 =*/ nb01, | ||
| 2235 | /*.nb02 =*/ nb02, | ||
| 2236 | /*.nb11 =*/ nb11, | ||
| 2237 | /*.nb21 =*/ nb21, | ||
| 2238 | }; | ||
| 2239 | |||
| 2240 | auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID); | ||
| 2241 | |||
| 2242 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 2243 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 2244 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 2245 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 2246 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); | ||
| 2247 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 4); | ||
| 2248 | |||
| 2249 | const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); | ||
| 2250 | |||
| 2251 | ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, 1, nth, 1, 1); | ||
| 2252 | |||
| 2253 | return 1; | ||
| 2254 | } | ||
| 2255 | |||
| 2256 | bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) { | ||
| 2257 | assert(op->op == GGML_OP_FLASH_ATTN_EXT); | ||
| 2258 | |||
| 2259 | const int64_t ne00 = op->src[0]->ne[0]; // head size | ||
| 2260 | const int64_t ne01 = op->src[0]->ne[1]; // batch size | ||
| 2261 | |||
| 2262 | // use vec kernel if the batch size is small and if the head size is supported | ||
| 2263 | return (ne01 < 20) && (ne00 % 32 == 0); | ||
| 2264 | } | ||
| 2265 | |||
| 2266 | size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) { | ||
| 2267 | assert(op->op == GGML_OP_FLASH_ATTN_EXT); | ||
| 2268 | |||
| 2269 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 2270 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 2271 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 2272 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 2273 | GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); | ||
| 2274 | GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); | ||
| 2275 | GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); | ||
| 2276 | GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); | ||
| 2277 | |||
| 2278 | size_t res = 0; | ||
| 2279 | |||
| 2280 | const bool has_mask = op->src[3] != nullptr; | ||
| 2281 | |||
| 2282 | // note: the non-vec kernel requires more extra memory, so always reserve for it | ||
| 2283 | GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG); | ||
| 2284 | |||
| 2285 | //if (ggml_metal_op_flash_attn_ext_use_vec(op)) { | ||
| 2286 | if (false) { | ||
| 2287 | // note: always reserve the padding space to avoid graph reallocations | ||
| 2288 | //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0; | ||
| 2289 | const bool has_kvpad = true; | ||
| 2290 | |||
| 2291 | if (has_kvpad) { | ||
| 2292 | res += OP_FLASH_ATTN_EXT_VEC_NCPSG*( | ||
| 2293 | nb11*ne12*ne13 + | ||
| 2294 | nb21*ne22*ne23 + | ||
| 2295 | (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); | ||
| 2296 | } | ||
| 2297 | } else { | ||
| 2298 | //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0; | ||
| 2299 | const bool has_kvpad = true; | ||
| 2300 | |||
| 2301 | if (has_kvpad) { | ||
| 2302 | res += OP_FLASH_ATTN_EXT_NCPSG*( | ||
| 2303 | nb11*ne12*ne13 + | ||
| 2304 | nb21*ne22*ne23 + | ||
| 2305 | (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); | ||
| 2306 | } | ||
| 2307 | } | ||
| 2308 | |||
| 2309 | return res; | ||
| 2310 | } | ||
| 2311 | |||
| 2312 | size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) { | ||
| 2313 | assert(op->op == GGML_OP_FLASH_ATTN_EXT); | ||
| 2314 | |||
| 2315 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 2316 | //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 2317 | //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 2318 | //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 2319 | //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); | ||
| 2320 | //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); | ||
| 2321 | GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); | ||
| 2322 | GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); | ||
| 2323 | |||
| 2324 | size_t res = 0; | ||
| 2325 | |||
| 2326 | const bool has_mask = op->src[3] != nullptr; | ||
| 2327 | |||
| 2328 | if (!has_mask) { | ||
| 2329 | return res; | ||
| 2330 | } | ||
| 2331 | |||
| 2332 | const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op); | ||
| 2333 | |||
| 2334 | // this optimization is not useful for the vector kernels | ||
| 2335 | // note: always reserve the blk buffer to avoid graph reallocations | ||
| 2336 | //if (is_vec) { | ||
| 2337 | // return res; | ||
| 2338 | //} | ||
| 2339 | |||
| 2340 | const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG; | ||
| 2341 | const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG; | ||
| 2342 | |||
| 2343 | const int64_t ne1 = (ne01 + nqptg - 1)/nqptg; | ||
| 2344 | const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg; | ||
| 2345 | |||
| 2346 | res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32); | ||
| 2347 | |||
| 2348 | return res; | ||
| 2349 | } | ||
| 2350 | |||
| 2351 | size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) { | ||
| 2352 | assert(op->op == GGML_OP_FLASH_ATTN_EXT); | ||
| 2353 | |||
| 2354 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 2355 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 2356 | //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 2357 | //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 2358 | GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); | ||
| 2359 | GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); | ||
| 2360 | //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); | ||
| 2361 | //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); | ||
| 2362 | |||
| 2363 | size_t res = 0; | ||
| 2364 | |||
| 2365 | // note: always reserve the temp buffer to avoid graph reallocations | ||
| 2366 | //if (ggml_metal_op_flash_attn_ext_use_vec(op)) { | ||
| 2367 | if (true) { | ||
| 2368 | const int64_t nwg = 32; | ||
| 2369 | const int64_t ne01_max = std::min(ne01, 32); | ||
| 2370 | |||
| 2371 | // temp buffer for writing the results from each workgroup | ||
| 2372 | // - ne20: the size of the Value head | ||
| 2373 | // - + 2: the S and M values for each intermediate result | ||
| 2374 | res += ggml_type_size(GGML_TYPE_F32)*(ne01_max*ne02*ne03*nwg*(ne20 + 2)); | ||
| 2375 | } | ||
| 2376 | |||
| 2377 | return res; | ||
| 2378 | } | ||
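
The temp buffer reserved above holds, for every staged query row, head, sequence, and workgroup, a partial output vector of `ne20` floats plus the running softmax statistics S and M, which is where the `(ne20 + 2)` factor comes from. A minimal standalone sketch of the same size computation, with all tensor dimensions made up for illustration:

```cpp
// Sketch only: reproduces the (ne20 + 2) sizing used for the flash-attention temp buffer.
// All shape values below are hypothetical.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    // hypothetical shapes: 1 query row, 32 heads, 1 sequence, V head size 128
    const int64_t ne01 = 1, ne02 = 32, ne03 = 1, ne20 = 128;

    const int64_t nwg      = 32;                          // workgroups per row (fixed above)
    const int64_t ne01_max = std::min<int64_t>(ne01, 32); // at most 32 query rows are staged

    // per (row, head, sequence, workgroup): ne20 partial output floats + the S and M values
    const size_t bytes = sizeof(float) * (ne01_max*ne02*ne03*nwg*(ne20 + 2));
    printf("temp buffer: %zu bytes\n", bytes);            // 532480 for these values
    return 0;
}
```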
| 2379 | |||
| 2380 | int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { | ||
| 2381 | ggml_tensor * op = ctx->node(idx); | ||
| 2382 | |||
| 2383 | ggml_metal_library_t lib = ctx->lib; | ||
| 2384 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 2385 | |||
| 2386 | const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx->dev); | ||
| 2387 | |||
| 2388 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 2389 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 2390 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 2391 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 2392 | GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); | ||
| 2393 | GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); | ||
| 2394 | GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); | ||
| 2395 | GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); | ||
| 2396 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 2397 | GGML_TENSOR_LOCALS( int32_t, nb, op, nb); | ||
| 2398 | |||
| 2399 | GGML_ASSERT(ne00 % 4 == 0); | ||
| 2400 | |||
| 2401 | GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); | ||
| 2402 | GGML_ASSERT(op->src[1]->type == op->src[2]->type); | ||
| 2403 | |||
| 2404 | //GGML_ASSERT(ggml_are_same_shape (src1, src2)); | ||
| 2405 | GGML_ASSERT(ne11 == ne21); | ||
| 2406 | GGML_ASSERT(ne12 == ne22); | ||
| 2407 | |||
| 2408 | GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16); | ||
| 2409 | GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] && | ||
| 2410 | "the Flash-Attention Metal kernel requires the mask to be at least n_queries big"); | ||
| 2411 | |||
| 2412 | float scale; | ||
| 2413 | float max_bias; | ||
| 2414 | float logit_softcap; | ||
| 2415 | |||
| 2416 | memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(scale)); | ||
| 2417 | memcpy(&max_bias, ((const int32_t *) op->op_params) + 1, sizeof(max_bias)); | ||
| 2418 | memcpy(&logit_softcap, ((const int32_t *) op->op_params) + 2, sizeof(logit_softcap)); | ||
| 2419 | |||
| 2420 | if (logit_softcap != 0.0f) { | ||
| 2421 | scale /= logit_softcap; | ||
| 2422 | } | ||
| 2423 | |||
| 2424 | const bool has_mask = op->src[3] != NULL; | ||
| 2425 | const bool has_sinks = op->src[4] != NULL; | ||
| 2426 | const bool has_bias = max_bias != 0.0f; | ||
| 2427 | const bool has_scap = logit_softcap != 0.0f; | ||
| 2428 | |||
| 2429 | const uint32_t n_head = op->src[0]->ne[2]; | ||
| 2430 | const int32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); | ||
| 2431 | |||
| 2432 | const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); | ||
| 2433 | const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); | ||
| 2434 | |||
| 2435 | GGML_ASSERT(ne01 < 65536); | ||
| 2436 | |||
| 2437 | ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); | ||
| 2438 | ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); | ||
| 2439 | ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]); | ||
| 2440 | ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0; | ||
| 2441 | ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0; | ||
| 2442 | |||
| 2443 | ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); | ||
| 2444 | |||
| 2445 | ggml_metal_buffer_id bid_pad = bid_dst; | ||
| 2446 | bid_pad.offs += ggml_nbytes(op); | ||
| 2447 | |||
| 2448 | ggml_metal_buffer_id bid_blk = bid_pad; | ||
| 2449 | bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op); | ||
| 2450 | |||
| 2451 | ggml_metal_buffer_id bid_tmp = bid_blk; | ||
| 2452 | bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op); | ||
| 2453 | |||
| 2454 | if (!ggml_metal_op_flash_attn_ext_use_vec(op)) { | ||
| 2455 | // half8x8 kernel | ||
| 2456 | const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup | ||
| 2457 | const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup | ||
| 2458 | |||
| 2459 | GGML_ASSERT(nqptg <= 32); | ||
| 2460 | GGML_ASSERT(nqptg % 8 == 0); | ||
| 2461 | GGML_ASSERT(ncpsg % 32 == 0); | ||
| 2462 | |||
| 2463 | bool need_sync = false; | ||
| 2464 | |||
| 2465 | const bool has_kvpad = ne11 % ncpsg != 0; | ||
| 2466 | |||
| 2467 | if (has_kvpad) { | ||
| 2468 | assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0); | ||
| 2469 | |||
| 2470 | ggml_metal_kargs_flash_attn_ext_pad args0 = { | ||
| 2471 | /*.ne11 =*/ne11, | ||
| 2472 | /*.ne_12_2 =*/ne12, | ||
| 2473 | /*.ne_12_3 =*/ne13, | ||
| 2474 | /*.nb11 =*/nb11, | ||
| 2475 | /*.nb12 =*/nb12, | ||
| 2476 | /*.nb13 =*/nb13, | ||
| 2477 | /*.nb21 =*/nb21, | ||
| 2478 | /*.nb22 =*/nb22, | ||
| 2479 | /*.nb23 =*/nb23, | ||
| 2480 | /*.ne31 =*/ne31, | ||
| 2481 | /*.ne32 =*/ne32, | ||
| 2482 | /*.ne33 =*/ne33, | ||
| 2483 | /*.nb31 =*/nb31, | ||
| 2484 | /*.nb32 =*/nb32, | ||
| 2485 | /*.nb33 =*/nb33, | ||
| 2486 | }; | ||
| 2487 | |||
| 2488 | auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); | ||
| 2489 | |||
| 2490 | ggml_metal_encoder_set_pipeline(enc, pipeline0); | ||
| 2491 | ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); | ||
| 2492 | ggml_metal_encoder_set_buffer (enc, bid_src1, 1); | ||
| 2493 | ggml_metal_encoder_set_buffer (enc, bid_src2, 2); | ||
| 2494 | ggml_metal_encoder_set_buffer (enc, bid_src3, 3); | ||
| 2495 | ggml_metal_encoder_set_buffer (enc, bid_pad, 4); | ||
| 2496 | |||
| 2497 | assert(ne12 == ne22); | ||
| 2498 | assert(ne13 == ne23); | ||
| 2499 | |||
| 2500 | ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); | ||
| 2501 | |||
| 2502 | need_sync = true; | ||
| 2503 | } | ||
| 2504 | |||
| 2505 | if (has_mask) { | ||
| 2506 | assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0); | ||
| 2507 | |||
| 2508 | ggml_metal_kargs_flash_attn_ext_blk args0 = { | ||
| 2509 | /*.ne01 =*/ ne01, | ||
| 2510 | /*.ne30 =*/ ne30, | ||
| 2511 | /*.ne31 =*/ ne31, | ||
| 2512 | /*.ne32 =*/ ne32, | ||
| 2513 | /*.ne33 =*/ ne33, | ||
| 2514 | /*.nb31 =*/ nb31, | ||
| 2515 | /*.nb32 =*/ nb32, | ||
| 2516 | /*.nb33 =*/ nb33, | ||
| 2517 | }; | ||
| 2518 | |||
| 2519 | auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg); | ||
| 2520 | |||
| 2521 | ggml_metal_encoder_set_pipeline(enc, pipeline0); | ||
| 2522 | ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); | ||
| 2523 | ggml_metal_encoder_set_buffer (enc, bid_src3, 1); | ||
| 2524 | ggml_metal_encoder_set_buffer (enc, bid_blk, 2); | ||
| 2525 | |||
| 2526 | const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg); | ||
| 2527 | const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg); | ||
| 2528 | |||
| 2529 | ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1); | ||
| 2530 | |||
| 2531 | need_sync = true; | ||
| 2532 | } | ||
| 2533 | |||
| 2534 | if (need_sync) { | ||
| 2535 | ggml_metal_op_concurrency_reset(ctx); | ||
| 2536 | } | ||
| 2537 | |||
| 2538 | const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0; | ||
| 2539 | |||
| 2540 | // 2*(2*ncpsg) | ||
| 2541 | // ncpsg soft_max values + ncpsg mask values | ||
| 2542 | // | ||
| 2543 | // 16*32*(nsg) | ||
| 2544 | // the shared memory needed for the simdgroups to load the KV cache | ||
| 2545 | // each thread loads (dequantizes) 16 head elements; there are 32 threads in the SG | ||
| 2546 | // | ||
| 2547 | #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*GGML_PAD(ne20, 64) + 2*(2*ncpsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16)) | ||
| 2548 | |||
| 2549 | //int64_t nsgmax = 4; | ||
| 2550 | // | ||
| 2551 | //if (is_q) { | ||
| 2552 | // nsgmax = 2; | ||
| 2553 | // while (true) { | ||
| 2554 | // const size_t smem = FATTN_SMEM(nsgmax); | ||
| 2555 | // if (smem > props_dev->max_theadgroup_memory_size) { | ||
| 2556 | // break; | ||
| 2557 | // } | ||
| 2558 | // nsgmax *= 2; | ||
| 2559 | // } | ||
| 2560 | // nsgmax /= 2; | ||
| 2561 | //} | ||
| 2562 | |||
| 2563 | // simdgroups per threadgroup (a.k.a. warps) | ||
| 2564 | //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; | ||
| 2565 | int32_t nsg = ne00 >= 512 ? 8 : 4; | ||
| 2566 | |||
| 2567 | const size_t smem = FATTN_SMEM(nsg); | ||
| 2568 | |||
| 2569 | ggml_metal_kargs_flash_attn_ext args = { | ||
| 2570 | /*.ne01 =*/ ne01, | ||
| 2571 | /*.ne02 =*/ ne02, | ||
| 2572 | /*.ne03 =*/ ne03, | ||
| 2573 | /*.nb01 =*/ nb01, | ||
| 2574 | /*.nb02 =*/ nb02, | ||
| 2575 | /*.nb03 =*/ nb03, | ||
| 2576 | /*.ne11 =*/ ne11, | ||
| 2577 | /*.ne_12_2 =*/ ne12, | ||
| 2578 | /*.ne_12_3 =*/ ne13, | ||
| 2579 | /*.ns10 =*/ int32_t(nb11/nb10), | ||
| 2580 | /*.nb11 =*/ nb11, | ||
| 2581 | /*.nb12 =*/ nb12, | ||
| 2582 | /*.nb13 =*/ nb13, | ||
| 2583 | /*.ns20 =*/ int32_t(nb21/nb20), | ||
| 2584 | /*.nb21 =*/ nb21, | ||
| 2585 | /*.nb22 =*/ nb22, | ||
| 2586 | /*.nb23 =*/ nb23, | ||
| 2587 | /*.ne31 =*/ ne31, | ||
| 2588 | /*.ne32 =*/ ne32, | ||
| 2589 | /*.ne33 =*/ ne33, | ||
| 2590 | /*.nb31 =*/ nb31, | ||
| 2591 | /*.nb32 =*/ nb32, | ||
| 2592 | /*.nb33 =*/ nb33, | ||
| 2593 | /*.ne1 =*/ ne1, | ||
| 2594 | /*.ne2 =*/ ne2, | ||
| 2595 | /*.ne3 =*/ ne3, | ||
| 2596 | /*.scale =*/ scale, | ||
| 2597 | /*.max_bias =*/ max_bias, | ||
| 2598 | /*.m0 =*/ m0, | ||
| 2599 | /*.m1 =*/ m1, | ||
| 2600 | /*.n_head_log2 =*/ n_head_log2, | ||
| 2601 | /*.logit_softcap =*/ logit_softcap, | ||
| 2602 | }; | ||
| 2603 | |||
| 2604 | auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg); | ||
| 2605 | |||
| 2606 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 2607 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 2608 | ggml_metal_encoder_set_buffer (enc, bid_src0, 1); | ||
| 2609 | ggml_metal_encoder_set_buffer (enc, bid_src1, 2); | ||
| 2610 | ggml_metal_encoder_set_buffer (enc, bid_src2, 3); | ||
| 2611 | ggml_metal_encoder_set_buffer (enc, bid_src3, 4); | ||
| 2612 | ggml_metal_encoder_set_buffer (enc, bid_src4, 5); | ||
| 2613 | ggml_metal_encoder_set_buffer (enc, bid_pad, 6); | ||
| 2614 | ggml_metal_encoder_set_buffer (enc, bid_blk, 7); | ||
| 2615 | ggml_metal_encoder_set_buffer (enc, bid_dst, 8); | ||
| 2616 | |||
| 2617 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 2618 | |||
| 2619 | ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03, 32, nsg, 1); | ||
| 2620 | #undef FATTN_SMEM | ||
| 2621 | } else { | ||
| 2622 | // half4x4 kernel | ||
| 2623 | const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup | ||
| 2624 | const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !! | ||
| 2625 | const int nhptg = 1; // heads per threadgroup | ||
| 2626 | |||
| 2627 | GGML_ASSERT(nqptg <= 32); | ||
| 2628 | GGML_ASSERT(nqptg % 1 == 0); | ||
| 2629 | GGML_ASSERT(ncpsg % 32 == 0); | ||
| 2630 | |||
| 2631 | bool need_sync = false; | ||
| 2632 | |||
| 2633 | const bool has_kvpad = ne11 % ncpsg != 0; | ||
| 2634 | |||
| 2635 | if (has_kvpad) { | ||
| 2636 | assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0); | ||
| 2637 | |||
| 2638 | ggml_metal_kargs_flash_attn_ext_pad args0 = { | ||
| 2639 | /*.ne11 =*/ne11, | ||
| 2640 | /*.ne_12_2 =*/ne12, | ||
| 2641 | /*.ne_12_3 =*/ne13, | ||
| 2642 | /*.nb11 =*/nb11, | ||
| 2643 | /*.nb12 =*/nb12, | ||
| 2644 | /*.nb13 =*/nb13, | ||
| 2645 | /*.nb21 =*/nb21, | ||
| 2646 | /*.nb22 =*/nb22, | ||
| 2647 | /*.nb23 =*/nb23, | ||
| 2648 | /*.ne31 =*/ne31, | ||
| 2649 | /*.ne32 =*/ne32, | ||
| 2650 | /*.ne33 =*/ne33, | ||
| 2651 | /*.nb31 =*/nb31, | ||
| 2652 | /*.nb32 =*/nb32, | ||
| 2653 | /*.nb33 =*/nb33, | ||
| 2654 | }; | ||
| 2655 | |||
| 2656 | auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); | ||
| 2657 | |||
| 2658 | ggml_metal_encoder_set_pipeline(enc, pipeline0); | ||
| 2659 | ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); | ||
| 2660 | ggml_metal_encoder_set_buffer (enc, bid_src1, 1); | ||
| 2661 | ggml_metal_encoder_set_buffer (enc, bid_src2, 2); | ||
| 2662 | ggml_metal_encoder_set_buffer (enc, bid_src3, 3); | ||
| 2663 | ggml_metal_encoder_set_buffer (enc, bid_pad, 4); | ||
| 2664 | |||
| 2665 | assert(ne12 == ne22); | ||
| 2666 | assert(ne13 == ne23); | ||
| 2667 | |||
| 2668 | ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); | ||
| 2669 | |||
| 2670 | need_sync = true; | ||
| 2671 | } | ||
| 2672 | |||
| 2673 | if (need_sync) { | ||
| 2674 | ggml_metal_op_concurrency_reset(ctx); | ||
| 2675 | } | ||
| 2676 | |||
| 2677 | // note: for simplicity, assume the K head size is larger than or equal to the V head size | ||
| 2678 | GGML_ASSERT(ne10 >= ne20); | ||
| 2679 | |||
| 2680 | // ne00 + 2*ncpsg*(nsg) | ||
| 2681 | // for each query, we load it as f16 in shared memory (ne00) | ||
| 2682 | // and store the soft_max values and the mask | ||
| 2683 | // | ||
| 2684 | // ne20*(nsg) | ||
| 2685 | // each simdgroup has a full f32 head vector in shared mem to accumulate results | ||
| 2686 | // | ||
| 2687 | #define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16)) | ||
| 2688 | |||
| 2689 | int64_t nsg = 1; | ||
| 2690 | |||
| 2691 | // workgroups | ||
| 2692 | // each workgroup handles nsg*ncpsg cache values | ||
| 2693 | int32_t nwg = 1; | ||
| 2694 | if (false) { | ||
| 2695 | // for small KV caches, we could launch a single workgroup and write the results directly to dst; | ||
| 2696 | // however, this does not lead to a significant improvement, so it is disabled | ||
| 2697 | nwg = 1; | ||
| 2698 | nsg = 4; | ||
| 2699 | } else { | ||
| 2700 | nwg = 32; | ||
| 2701 | nsg = 1; | ||
| 2702 | while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) { | ||
| 2703 | nsg *= 2; | ||
| 2704 | } | ||
| 2705 | } | ||
| 2706 | |||
| 2707 | ggml_metal_kargs_flash_attn_ext_vec args = { | ||
| 2708 | /*.ne01 =*/ ne01, | ||
| 2709 | /*.ne02 =*/ ne02, | ||
| 2710 | /*.ne03 =*/ ne03, | ||
| 2711 | /*.nb01 =*/ nb01, | ||
| 2712 | /*.nb02 =*/ nb02, | ||
| 2713 | /*.nb03 =*/ nb03, | ||
| 2714 | /*.ne11 =*/ ne11, | ||
| 2715 | /*.ne_12_2 =*/ ne12, | ||
| 2716 | /*.ne_12_3 =*/ ne13, | ||
| 2717 | /*.ns10 =*/ int32_t(nb11/nb10), | ||
| 2718 | /*.nb11 =*/ nb11, | ||
| 2719 | /*.nb12 =*/ nb12, | ||
| 2720 | /*.nb13 =*/ nb13, | ||
| 2721 | /*.ns20 =*/ int32_t(nb21/nb20), | ||
| 2722 | /*.nb21 =*/ nb21, | ||
| 2723 | /*.nb22 =*/ nb22, | ||
| 2724 | /*.nb23 =*/ nb23, | ||
| 2725 | /*.ne31 =*/ ne31, | ||
| 2726 | /*.ne32 =*/ ne32, | ||
| 2727 | /*.ne33 =*/ ne33, | ||
| 2728 | /*.nb31 =*/ nb31, | ||
| 2729 | /*.nb32 =*/ nb32, | ||
| 2730 | /*.nb33 =*/ nb33, | ||
| 2731 | /*.ne1 =*/ ne1, | ||
| 2732 | /*.ne2 =*/ ne2, | ||
| 2733 | /*.ne3 =*/ ne3, | ||
| 2734 | /*.scale =*/ scale, | ||
| 2735 | /*.max_bias =*/ max_bias, | ||
| 2736 | /*.m0 =*/ m0, | ||
| 2737 | /*.m1 =*/ m1, | ||
| 2738 | /*.n_head_log2 =*/ n_head_log2, | ||
| 2739 | /*.logit_softcap =*/ logit_softcap, | ||
| 2740 | }; | ||
| 2741 | |||
| 2742 | auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); | ||
| 2743 | |||
| 2744 | GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); | ||
| 2745 | |||
| 2746 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 2747 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 2748 | ggml_metal_encoder_set_buffer (enc, bid_src0, 1); | ||
| 2749 | ggml_metal_encoder_set_buffer (enc, bid_src1, 2); | ||
| 2750 | ggml_metal_encoder_set_buffer (enc, bid_src2, 3); | ||
| 2751 | ggml_metal_encoder_set_buffer (enc, bid_src3, 4); | ||
| 2752 | ggml_metal_encoder_set_buffer (enc, bid_src4, 5); | ||
| 2753 | |||
| 2754 | const size_t smem = FATTN_SMEM(nsg); | ||
| 2755 | |||
| 2756 | //printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, props_dev->max_theadgroup_memory_size, (int) nsg, (int) nsgmax); | ||
| 2757 | GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size); | ||
| 2758 | |||
| 2759 | if (nwg == 1) { | ||
| 2760 | assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0); | ||
| 2761 | |||
| 2762 | // using 1 workgroup -> write the result directly into dst | ||
| 2763 | ggml_metal_encoder_set_buffer(enc, bid_pad, 6); | ||
| 2764 | ggml_metal_encoder_set_buffer(enc, bid_dst, 7); | ||
| 2765 | |||
| 2766 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 2767 | |||
| 2768 | ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1); | ||
| 2769 | } else { | ||
| 2770 | // sanity checks | ||
| 2771 | assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0); | ||
| 2772 | |||
| 2773 | GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3); | ||
| 2774 | GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31)); | ||
| 2775 | |||
| 2776 | // write the results from each workgroup into a temp buffer | ||
| 2777 | ggml_metal_encoder_set_buffer(enc, bid_pad, 6); | ||
| 2778 | ggml_metal_encoder_set_buffer(enc, bid_tmp, 7); | ||
| 2779 | |||
| 2780 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 2781 | ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1); | ||
| 2782 | |||
| 2783 | // sync the 2 kernels | ||
| 2784 | ggml_metal_op_concurrency_reset(ctx); | ||
| 2785 | |||
| 2786 | // reduce the results from the workgroups | ||
| 2787 | { | ||
| 2788 | const int32_t nrows = ne1*ne2*ne3; | ||
| 2789 | |||
| 2790 | ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = { | ||
| 2791 | nrows, | ||
| 2792 | }; | ||
| 2793 | |||
| 2794 | auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg); | ||
| 2795 | |||
| 2796 | ggml_metal_encoder_set_pipeline(enc, pipeline0); | ||
| 2797 | ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); | ||
| 2798 | ggml_metal_encoder_set_buffer (enc, bid_tmp, 1); | ||
| 2799 | ggml_metal_encoder_set_buffer (enc, bid_dst, 2); | ||
| 2800 | |||
| 2801 | ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, 32*nwg, 1, 1); | ||
| 2802 | } | ||
| 2803 | } | ||
| 2804 | #undef FATTN_SMEM | ||
| 2805 | } | ||
| 2806 | |||
| 2807 | return 1; | ||
| 2808 | } | ||
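
In the multi-workgroup vector path above, each workgroup writes an unnormalized partial output together with its running softmax statistics S and M, and a separate reduce kernel combines them into the final result. The sketch below shows the standard way such partials are merged, written for the CPU purely as an illustration of the math; it is not the Metal reduce kernel, and all names are hypothetical.

```cpp
// Sketch: combining per-workgroup flash-attention partials (o_i, S_i, M_i)
// into one normalized output row. Illustrative only.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct Partial {
    std::vector<float> o; // unnormalized partial output, length = V head size
    float S;              // sum of exp(score - M) accumulated by this workgroup
    float M;              // running max of the scores seen by this workgroup
};

std::vector<float> reduce_partials(const std::vector<Partial> & parts) {
    const size_t d = parts[0].o.size();

    // global max over all workgroups
    float M = -INFINITY;
    for (const auto & p : parts) M = std::max(M, p.M);

    // rescale each partial to the global max, then accumulate
    std::vector<float> o(d, 0.0f);
    float S = 0.0f;
    for (const auto & p : parts) {
        const float w = std::exp(p.M - M);
        for (size_t i = 0; i < d; ++i) o[i] += w*p.o[i];
        S += w*p.S;
    }

    // final normalization by the combined softmax denominator
    for (size_t i = 0; i < d; ++i) o[i] /= S;
    return o;
}

int main() {
    std::vector<Partial> parts = {
        {{1.0f, 2.0f}, 1.5f, 0.3f},
        {{0.5f, 0.5f}, 2.0f, 0.1f},
    };
    const auto o = reduce_partials(parts);
    printf("%f %f\n", o[0], o[1]);
    return 0;
}
```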
| 2809 | |||
| 2810 | int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { | ||
| 2811 | ggml_tensor * op = ctx->node(idx); | ||
| 2812 | |||
| 2813 | ggml_metal_library_t lib = ctx->lib; | ||
| 2814 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 2815 | |||
| 2816 | const bool use_fusion = ctx->use_fusion; | ||
| 2817 | |||
| 2818 | const int debug_fusion = ctx->debug_fusion; | ||
| 2819 | |||
| 2820 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 2821 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 2822 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 2823 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 2824 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 2825 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 2826 | |||
| 2827 | GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); | ||
| 2828 | GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); | ||
| 2829 | |||
| 2830 | GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); | ||
| 2831 | GGML_ASSERT(ggml_is_contiguous_rows(op->src[1])); | ||
| 2832 | |||
| 2833 | ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); | ||
| 2834 | ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); | ||
| 2835 | ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); | ||
| 2836 | |||
| 2837 | ggml_metal_kargs_bin args = { | ||
| 2838 | /*.ne00 =*/ ne00, | ||
| 2839 | /*.ne01 =*/ ne01, | ||
| 2840 | /*.ne02 =*/ ne02, | ||
| 2841 | /*.ne03 =*/ ne03, | ||
| 2842 | /*.nb00 =*/ nb00, | ||
| 2843 | /*.nb01 =*/ nb01, | ||
| 2844 | /*.nb02 =*/ nb02, | ||
| 2845 | /*.nb03 =*/ nb03, | ||
| 2846 | /*.ne10 =*/ ne10, | ||
| 2847 | /*.ne11 =*/ ne11, | ||
| 2848 | /*.ne12 =*/ ne12, | ||
| 2849 | /*.ne13 =*/ ne13, | ||
| 2850 | /*.nb10 =*/ nb10, | ||
| 2851 | /*.nb11 =*/ nb11, | ||
| 2852 | /*.nb12 =*/ nb12, | ||
| 2853 | /*.nb13 =*/ nb13, | ||
| 2854 | /*.ne0 =*/ ne0, | ||
| 2855 | /*.ne1 =*/ ne1, | ||
| 2856 | /*.ne2 =*/ ne2, | ||
| 2857 | /*.ne3 =*/ ne3, | ||
| 2858 | /*.nb0 =*/ nb0, | ||
| 2859 | /*.nb1 =*/ nb1, | ||
| 2860 | /*.nb2 =*/ nb2, | ||
| 2861 | /*.nb3 =*/ nb3, | ||
| 2862 | /*.offs =*/ 0, | ||
| 2863 | /*.o1 =*/ { bid_src1.offs }, | ||
| 2864 | }; | ||
| 2865 | |||
| 2866 | ggml_op fops[8]; | ||
| 2867 | |||
| 2868 | int n_fuse = 1; | ||
| 2869 | |||
| 2870 | // c[0] = add(a, b[0]) | ||
| 2871 | // c[1] = add(c[0], b[1]) | ||
| 2872 | // c[2] = add(c[1], b[2]) | ||
| 2873 | // ... | ||
| 2874 | if (use_fusion) { | ||
| 2875 | fops[0] = GGML_OP_ADD; | ||
| 2876 | fops[1] = GGML_OP_ADD; | ||
| 2877 | fops[2] = GGML_OP_ADD; | ||
| 2878 | fops[3] = GGML_OP_ADD; | ||
| 2879 | fops[4] = GGML_OP_ADD; | ||
| 2880 | fops[5] = GGML_OP_ADD; | ||
| 2881 | fops[6] = GGML_OP_ADD; | ||
| 2882 | fops[7] = GGML_OP_ADD; | ||
| 2883 | |||
| 2884 | // note: in Metal, we sometimes encode the graph in parallel, so we have to avoid fusing ops | ||
| 2885 | // across splits. idx_end indicates the last node in the current split | ||
| 2886 | for (n_fuse = 0; n_fuse <= 6; ++n_fuse) { | ||
| 2887 | if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) { | ||
| 2888 | break; | ||
| 2889 | } | ||
| 2890 | |||
| 2891 | ggml_tensor * f0 = ctx->node(idx + n_fuse); | ||
| 2892 | ggml_tensor * f1 = ctx->node(idx + n_fuse + 1); | ||
| 2893 | |||
| 2894 | if (f0 != f1->src[0]) { | ||
| 2895 | break; | ||
| 2896 | } | ||
| 2897 | |||
| 2898 | // b[0] === b[1] === ... | ||
| 2899 | if (!ggml_are_same_layout(f0->src[1], f1->src[1])) { | ||
| 2900 | break; | ||
| 2901 | } | ||
| 2902 | |||
| 2903 | // only fuse ops if src1 is in the same Metal buffer | ||
| 2904 | ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(f1->src[1]); | ||
| 2905 | if (bid_fuse.metal != bid_src1.metal) { | ||
| 2906 | break; | ||
| 2907 | } | ||
| 2908 | |||
| 2909 | //ctx->fuse_cnt[ops[n_fuse + 1]->op]++; | ||
| 2910 | |||
| 2911 | args.o1[n_fuse + 1] = bid_fuse.offs; | ||
| 2912 | } | ||
| 2913 | |||
| 2914 | ++n_fuse; | ||
| 2915 | |||
| 2916 | if (debug_fusion > 1 && n_fuse > 1) { | ||
| 2917 | GGML_LOG_DEBUG("%s: fuse: ADD x %d\n", __func__, n_fuse); | ||
| 2918 | } | ||
| 2919 | } | ||
| 2920 | |||
| 2921 | // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer | ||
| 2922 | bid_src1.offs = 0; | ||
| 2923 | |||
| 2924 | struct ggml_metal_pipeline_with_params pipeline; | ||
| 2925 | |||
| 2926 | pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse); | ||
| 2927 | |||
| 2928 | if (n_fuse > 1) { | ||
| 2929 | bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); | ||
| 2930 | |||
| 2931 | for (int i = 1; i < n_fuse; ++i) { | ||
| 2932 | if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) { | ||
| 2933 | ggml_metal_op_concurrency_reset(ctx); | ||
| 2934 | |||
| 2935 | break; | ||
| 2936 | } | ||
| 2937 | } | ||
| 2938 | } | ||
| 2939 | |||
| 2940 | if (pipeline.c4) { | ||
| 2941 | args.ne00 = ne00/4; | ||
| 2942 | args.ne10 = ne10/4; | ||
| 2943 | args.ne0 = ne0/4; | ||
| 2944 | } | ||
| 2945 | |||
| 2946 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 2947 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 2948 | ggml_metal_encoder_set_buffer (enc, bid_src0, 1); | ||
| 2949 | ggml_metal_encoder_set_buffer (enc, bid_src1, 2); | ||
| 2950 | ggml_metal_encoder_set_buffer (enc, bid_dst, 3); | ||
| 2951 | |||
| 2952 | if (pipeline.cnt) { | ||
| 2953 | const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op); | ||
| 2954 | |||
| 2955 | ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); | ||
| 2956 | } else { | ||
| 2957 | const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); | ||
| 2958 | |||
| 2959 | int nth = 1; | ||
| 2960 | |||
| 2961 | while (2*nth < args.ne0 && nth < nth_max) { | ||
| 2962 | nth *= 2; | ||
| 2963 | } | ||
| 2964 | |||
| 2965 | ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); | ||
| 2966 | } | ||
| 2967 | |||
| 2968 | return n_fuse; | ||
| 2969 | } | ||
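
The fusion scan above walks forward from the current node and keeps extending the chain while each next node is an ADD whose first source is the previous result and whose second source has the same layout and lives in the same Metal buffer. A simplified, self-contained sketch of that scan follows; the node and layout fields here are stand-ins for the ggml structures, not the real types.

```cpp
// Sketch of the ADD-chain fusion scan. Types are simplified stand-ins.
#include <cstdio>
#include <vector>

struct Node {
    const Node * src0 = nullptr; // accumulator input (previous result in the chain)
    int          b_layout = 0;   // stand-in for "layout of src1"
    int          b_buffer = 0;   // stand-in for "Metal buffer holding src1"
};

// returns how many consecutive nodes starting at idx can be fused into one kernel
int count_fusable_adds(const std::vector<Node> & nodes, int idx, int max_fuse) {
    int n = 1;
    while (n < max_fuse && idx + n < (int) nodes.size()) {
        const Node & prev = nodes[idx + n - 1];
        const Node & next = nodes[idx + n];
        if (next.src0 != &prev)             break; // must consume the previous result
        if (next.b_layout != prev.b_layout) break; // b[0] === b[1] === ...
        if (next.b_buffer != prev.b_buffer) break; // src1 must be in the same buffer
        ++n;
    }
    return n;
}

int main() {
    std::vector<Node> nodes(3);
    nodes[1].src0 = &nodes[0];
    nodes[2].src0 = &nodes[1];
    printf("fused: %d\n", count_fusable_adds(nodes, 0, 8)); // prints 3
    return 0;
}
```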
| 2970 | |||
| 2971 | int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { | ||
| 2972 | ggml_tensor * op = ctx->node(idx); | ||
| 2973 | |||
| 2974 | ggml_metal_library_t lib = ctx->lib; | ||
| 2975 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 2976 | |||
| 2977 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 2978 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 2979 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 2980 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 2981 | |||
| 2982 | GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); | ||
| 2983 | |||
| 2984 | ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); | ||
| 2985 | ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); | ||
| 2986 | |||
| 2987 | float eps; | ||
| 2988 | memcpy(&eps, op->op_params, sizeof(float)); | ||
| 2989 | |||
| 2990 | ggml_metal_kargs_l2_norm args = { | ||
| 2991 | /*.ne00 =*/ ne00, | ||
| 2992 | /*.ne01 =*/ ne01, | ||
| 2993 | /*.ne02 =*/ ne02, | ||
| 2994 | /*.ne03 =*/ ne03, | ||
| 2995 | /*.nb00 =*/ nb00, | ||
| 2996 | /*.nb01 =*/ nb01, | ||
| 2997 | /*.nb02 =*/ nb02, | ||
| 2998 | /*.nb03 =*/ nb03, | ||
| 2999 | /*.ne0 =*/ ne0, | ||
| 3000 | /*.ne1 =*/ ne1, | ||
| 3001 | /*.ne2 =*/ ne2, | ||
| 3002 | /*.ne3 =*/ ne3, | ||
| 3003 | /*.nb0 =*/ nb0, | ||
| 3004 | /*.nb1 =*/ nb1, | ||
| 3005 | /*.nb2 =*/ nb2, | ||
| 3006 | /*.nb3 =*/ nb3, | ||
| 3007 | /*.eps =*/ eps, | ||
| 3008 | }; | ||
| 3009 | |||
| 3010 | auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); | ||
| 3011 | |||
| 3012 | if (pipeline.c4) { | ||
| 3013 | args.ne00 = ne00/4; | ||
| 3014 | args.ne0 = ne0/4; | ||
| 3015 | } | ||
| 3016 | |||
| 3017 | int nth = 32; // SIMD width | ||
| 3018 | |||
| 3019 | while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { | ||
| 3020 | nth *= 2; | ||
| 3021 | } | ||
| 3022 | |||
| 3023 | nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); | ||
| 3024 | |||
| 3025 | const size_t smem = pipeline.smem; | ||
| 3026 | |||
| 3027 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 3028 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 3029 | ggml_metal_encoder_set_buffer (enc, bid_src0, 1); | ||
| 3030 | ggml_metal_encoder_set_buffer (enc, bid_dst, 2); | ||
| 3031 | |||
| 3032 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 3033 | |||
| 3034 | ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); | ||
| 3035 | |||
| 3036 | return 1; | ||
| 3037 | } | ||
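
The `nth` loop above picks the threadgroup width for this row-wise kernel: start at the SIMD width and keep doubling while the row is wider and the pipeline still allows more threads, then clamp to the pipeline limit. A tiny standalone version of that sizing rule, with a made-up pipeline limit:

```cpp
// Sketch of the threadgroup-width selection used by the row-wise kernels.
// The pipeline limit passed in main() is hypothetical.
#include <algorithm>
#include <cstdio>

int pick_nth(int row_width, int max_threads_per_tg) {
    int nth = 32; // SIMD width
    while (nth < row_width && nth < max_threads_per_tg) {
        nth *= 2;
    }
    return std::min(nth, max_threads_per_tg);
}

int main() {
    printf("%d\n", pick_nth(4096, 1024)); // 1024: capped by the pipeline limit
    printf("%d\n", pick_nth(96,   1024)); // 128: next power of two covering the row
    return 0;
}
```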
| 3038 | |||
| 3039 | int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) { | ||
| 3040 | ggml_tensor * op = ctx->node(idx); | ||
| 3041 | |||
| 3042 | ggml_metal_library_t lib = ctx->lib; | ||
| 3043 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 3044 | |||
| 3045 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 3046 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 3047 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 3048 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 3049 | |||
| 3050 | const int32_t ngrp = ((const int32_t *) op->op_params)[0]; | ||
| 3051 | |||
| 3052 | float eps; | ||
| 3053 | memcpy(&eps, op->op_params + 1, sizeof(float)); | ||
| 3054 | |||
| 3055 | ggml_metal_kargs_group_norm args = { | ||
| 3056 | /*.ne00 =*/ ne00, | ||
| 3057 | /*.ne01 =*/ ne01, | ||
| 3058 | /*.ne02 =*/ ne02, | ||
| 3059 | /*.nb00 =*/ nb00, | ||
| 3060 | /*.nb01 =*/ nb01, | ||
| 3061 | /*.nb02 =*/ nb02, | ||
| 3062 | /*.ngrp =*/ ngrp, | ||
| 3063 | /*.eps =*/ eps, | ||
| 3064 | }; | ||
| 3065 | |||
| 3066 | auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op); | ||
| 3067 | |||
| 3068 | int nth = 32; // SIMD width | ||
| 3069 | //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { | ||
| 3070 | // nth *= 2; | ||
| 3071 | //} | ||
| 3072 | |||
| 3073 | //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); | ||
| 3074 | //nth = std::min(nth, ne00/4); | ||
| 3075 | |||
| 3076 | const size_t smem = pipeline.smem; | ||
| 3077 | |||
| 3078 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 3079 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 3080 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 3081 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); | ||
| 3082 | |||
| 3083 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 3084 | |||
| 3085 | ggml_metal_encoder_dispatch_threadgroups(enc, ngrp, 1, 1, nth, 1, 1); | ||
| 3086 | |||
| 3087 | return 1; | ||
| 3088 | } | ||
| 3089 | |||
| 3090 | int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) { | ||
| 3091 | ggml_tensor * op = ctx->node(idx); | ||
| 3092 | |||
| 3093 | ggml_metal_library_t lib = ctx->lib; | ||
| 3094 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 3095 | |||
| 3096 | const bool use_fusion = ctx->use_fusion; | ||
| 3097 | |||
| 3098 | const int debug_fusion = ctx->debug_fusion; | ||
| 3099 | |||
| 3100 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 3101 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 3102 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 3103 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 3104 | |||
| 3105 | float eps; | ||
| 3106 | memcpy(&eps, op->op_params, sizeof(float)); | ||
| 3107 | |||
| 3108 | ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); | ||
| 3109 | ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); | ||
| 3110 | |||
| 3111 | ggml_metal_kargs_norm args = { | ||
| 3112 | /*.ne00 =*/ ne00, | ||
| 3113 | /*.ne00_t =*/ ne00 % 4 == 0 ? ne00/4 : ne00, | ||
| 3114 | /*.nb1 =*/ nb1, | ||
| 3115 | /*.nb2 =*/ nb2, | ||
| 3116 | /*.nb3 =*/ nb3, | ||
| 3117 | /*.eps =*/ eps, | ||
| 3118 | /*.nef1 =*/ { ne01 }, | ||
| 3119 | /*.nef2 =*/ { ne02 }, | ||
| 3120 | /*.nef3 =*/ { ne03 }, | ||
| 3121 | /*.nbf1 =*/ { nb01 }, | ||
| 3122 | /*.nbf2 =*/ { nb02 }, | ||
| 3123 | /*.nbf3 =*/ { nb03 }, | ||
| 3124 | }; | ||
| 3125 | |||
| 3126 | ggml_op fops[8]; | ||
| 3127 | |||
| 3128 | int n_fuse = 1; | ||
| 3129 | |||
| 3130 | ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 }; | ||
| 3131 | |||
| 3132 | // d[0] = norm(a) | ||
| 3133 | // d[1] = mul(d[0], b) | ||
| 3134 | // d[2] = add(d[1], c) | ||
| 3135 | if (use_fusion) { | ||
| 3136 | fops[0] = op->op; | ||
| 3137 | fops[1] = GGML_OP_MUL; | ||
| 3138 | fops[2] = GGML_OP_ADD; | ||
| 3139 | |||
| 3140 | for (n_fuse = 0; n_fuse <= 1; ++n_fuse) { | ||
| 3141 | if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) { | ||
| 3142 | break; | ||
| 3143 | } | ||
| 3144 | |||
| 3145 | ggml_tensor * f0 = ctx->node(idx + n_fuse); | ||
| 3146 | ggml_tensor * f1 = ctx->node(idx + n_fuse + 1); | ||
| 3147 | |||
| 3148 | if (f0 != f1->src[0]) { | ||
| 3149 | break; | ||
| 3150 | } | ||
| 3151 | |||
| 3152 | if (f1->src[1]->ne[0] != op->ne[0]) { | ||
| 3153 | break; | ||
| 3154 | } | ||
| 3155 | |||
| 3156 | if (!ggml_is_contiguous_rows(f1->src[1])) { | ||
| 3157 | break; | ||
| 3158 | } | ||
| 3159 | |||
| 3160 | if (f1->type != GGML_TYPE_F32) { | ||
| 3161 | break; | ||
| 3162 | } | ||
| 3163 | |||
| 3164 | //ctx->fuse_cnt[f1->op]++; | ||
| 3165 | |||
| 3166 | bid_fuse[n_fuse] = ggml_metal_get_buffer_id(f1->src[1]); | ||
| 3167 | |||
| 3168 | args.nef1[n_fuse + 1] = f1->src[1]->ne[1]; | ||
| 3169 | args.nef2[n_fuse + 1] = f1->src[1]->ne[2]; | ||
| 3170 | args.nef3[n_fuse + 1] = f1->src[1]->ne[3]; | ||
| 3171 | |||
| 3172 | args.nbf1[n_fuse + 1] = f1->src[1]->nb[1]; | ||
| 3173 | args.nbf2[n_fuse + 1] = f1->src[1]->nb[2]; | ||
| 3174 | args.nbf3[n_fuse + 1] = f1->src[1]->nb[3]; | ||
| 3175 | } | ||
| 3176 | |||
| 3177 | ++n_fuse; | ||
| 3178 | |||
| 3179 | if (debug_fusion > 1 && n_fuse > 1) { | ||
| 3180 | if (n_fuse == 2) { | ||
| 3181 | GGML_LOG_DEBUG("%s: fuse: %s + MUL\n", __func__, ggml_op_name(op->op)); | ||
| 3182 | } | ||
| 3183 | if (n_fuse == 3) { | ||
| 3184 | GGML_LOG_DEBUG("%s: fuse: %s + MUL + ADD\n", __func__, ggml_op_name(op->op)); | ||
| 3185 | } | ||
| 3186 | } | ||
| 3187 | } | ||
| 3188 | |||
| 3189 | if (n_fuse > 1) { | ||
| 3190 | bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); | ||
| 3191 | |||
| 3192 | for (int i = 1; i < n_fuse; ++i) { | ||
| 3193 | if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) { | ||
| 3194 | ggml_metal_op_concurrency_reset(ctx); | ||
| 3195 | |||
| 3196 | break; | ||
| 3197 | } | ||
| 3198 | } | ||
| 3199 | } | ||
| 3200 | |||
| 3201 | auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse); | ||
| 3202 | |||
| 3203 | int nth = 32; // SIMD width | ||
| 3204 | |||
| 3205 | while (nth < args.ne00_t && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { | ||
| 3206 | nth *= 2; | ||
| 3207 | } | ||
| 3208 | |||
| 3209 | nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); | ||
| 3210 | nth = std::min(nth, args.ne00_t); | ||
| 3211 | |||
| 3212 | const size_t smem = pipeline.smem; | ||
| 3213 | |||
| 3214 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 3215 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 3216 | ggml_metal_encoder_set_buffer (enc, bid_src0, 1); | ||
| 3217 | ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2); | ||
| 3218 | ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3); | ||
| 3219 | ggml_metal_encoder_set_buffer (enc, bid_dst, 4); | ||
| 3220 | |||
| 3221 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 3222 | |||
| 3223 | ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); | ||
| 3224 | |||
| 3225 | return n_fuse; | ||
| 3226 | } | ||
| 3227 | |||
| 3228 | int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) { | ||
| 3229 | ggml_tensor * op = ctx->node(idx); | ||
| 3230 | |||
| 3231 | ggml_metal_library_t lib = ctx->lib; | ||
| 3232 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 3233 | |||
| 3234 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 3235 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 3236 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 3237 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 3238 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 3239 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 3240 | |||
| 3241 | // make sure we have one or more position ids (ne10) per token (ne02) | ||
| 3242 | GGML_ASSERT(ne10 % ne02 == 0); | ||
| 3243 | GGML_ASSERT(ne10 >= ne02); | ||
| 3244 | |||
| 3245 | const int nth = std::min(1024, ne00); | ||
| 3246 | |||
| 3247 | const int n_past = ((const int32_t *) op->op_params)[0]; | ||
| 3248 | const int n_dims = ((const int32_t *) op->op_params)[1]; | ||
| 3249 | //const int mode = ((const int32_t *) op->op_params)[2]; | ||
| 3250 | // skip index 3 (n_ctx), used in GLM RoPE, not implemented in Metal | ||
| 3251 | const int n_ctx_orig = ((const int32_t *) op->op_params)[4]; | ||
| 3252 | |||
| 3253 | float freq_base; | ||
| 3254 | float freq_scale; | ||
| 3255 | float ext_factor; | ||
| 3256 | float attn_factor; | ||
| 3257 | float beta_fast; | ||
| 3258 | float beta_slow; | ||
| 3259 | |||
| 3260 | memcpy(&freq_base, (const int32_t *) op->op_params + 5, sizeof(float)); | ||
| 3261 | memcpy(&freq_scale, (const int32_t *) op->op_params + 6, sizeof(float)); | ||
| 3262 | memcpy(&ext_factor, (const int32_t *) op->op_params + 7, sizeof(float)); | ||
| 3263 | memcpy(&attn_factor, (const int32_t *) op->op_params + 8, sizeof(float)); | ||
| 3264 | memcpy(&beta_fast, (const int32_t *) op->op_params + 9, sizeof(float)); | ||
| 3265 | memcpy(&beta_slow, (const int32_t *) op->op_params + 10, sizeof(float)); | ||
| 3266 | |||
| 3267 | // mrope | ||
| 3268 | const int sect_0 = ((const int32_t *) op->op_params)[11]; | ||
| 3269 | const int sect_1 = ((const int32_t *) op->op_params)[12]; | ||
| 3270 | const int sect_2 = ((const int32_t *) op->op_params)[13]; | ||
| 3271 | const int sect_3 = ((const int32_t *) op->op_params)[14]; | ||
| 3272 | |||
| 3273 | ggml_metal_kargs_rope args = { | ||
| 3274 | /*.ne00 =*/ ne00, | ||
| 3275 | /*.ne01 =*/ ne01, | ||
| 3276 | /*.ne02 =*/ ne02, | ||
| 3277 | /*.ne03 =*/ ne03, | ||
| 3278 | /*.nb00 =*/ nb00, | ||
| 3279 | /*.nb01 =*/ nb01, | ||
| 3280 | /*.nb02 =*/ nb02, | ||
| 3281 | /*.nb03 =*/ nb03, | ||
| 3282 | /*.ne0 =*/ ne0, | ||
| 3283 | /*.ne1 =*/ ne1, | ||
| 3284 | /*.ne2 =*/ ne2, | ||
| 3285 | /*.ne3 =*/ ne3, | ||
| 3286 | /*.nb0 =*/ nb0, | ||
| 3287 | /*.nb1 =*/ nb1, | ||
| 3288 | /*.nb2 =*/ nb2, | ||
| 3289 | /*.nb3 =*/ nb3, | ||
| 3290 | /*.n_past =*/ n_past, | ||
| 3291 | /*.n_dims =*/ n_dims, | ||
| 3292 | /*.n_ctx_orig =*/ n_ctx_orig, | ||
| 3293 | /*.freq_base =*/ freq_base, | ||
| 3294 | /*.freq_scale =*/ freq_scale, | ||
| 3295 | /*.ext_factor =*/ ext_factor, | ||
| 3296 | /*.attn_factor =*/ attn_factor, | ||
| 3297 | /*.beta_fast =*/ beta_fast, | ||
| 3298 | /*.beta_slow =*/ beta_slow, | ||
| 3299 | /* sect_0 =*/ sect_0, | ||
| 3300 | /* sect_1 =*/ sect_1, | ||
| 3301 | /* sect_2 =*/ sect_2, | ||
| 3302 | /* sect_3 =*/ sect_3, | ||
| 3303 | /* src2 =*/ op->src[2] != nullptr, | ||
| 3304 | }; | ||
| 3305 | |||
| 3306 | auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op); | ||
| 3307 | |||
| 3308 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 3309 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 3310 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 3311 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 3312 | if (op->src[2]) { | ||
| 3313 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); | ||
| 3314 | } else { | ||
| 3315 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 3); | ||
| 3316 | } | ||
| 3317 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 4); | ||
| 3318 | |||
| 3319 | ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); | ||
| 3320 | |||
| 3321 | return 1; | ||
| 3322 | } | ||
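
The RoPE parameters above are packed into the int32 `op_params` array, and the float values are recovered with `memcpy` rather than a pointer cast, which avoids type-punning issues. A minimal sketch of that unpacking pattern; the parameter indices used here are just an example, not the full ggml layout:

```cpp
// Sketch: recovering float parameters stored in an int32 array via memcpy,
// the same pattern used for freq_base/freq_scale/... above. Layout is illustrative.
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    int32_t op_params[16] = {0};

    // pack: an int parameter at index 1 and a float parameter at index 5 (example layout)
    const int32_t n_dims    = 64;
    const float   freq_base = 10000.0f;
    op_params[1] = n_dims;
    memcpy(&op_params[5], &freq_base, sizeof(float));

    // unpack: ints are read directly, floats are memcpy'd back out
    const int32_t n_dims_out = op_params[1];
    float freq_base_out;
    memcpy(&freq_base_out, &op_params[5], sizeof(float));

    printf("n_dims = %d, freq_base = %.1f\n", n_dims_out, freq_base_out);
    return 0;
}
```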
| 3323 | |||
| 3324 | int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { | ||
| 3325 | ggml_tensor * op = ctx->node(idx); | ||
| 3326 | |||
| 3327 | ggml_metal_library_t lib = ctx->lib; | ||
| 3328 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 3329 | |||
| 3330 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 3331 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 3332 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 3333 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 3334 | |||
| 3335 | const int32_t s0 = ((const int32_t *)(op->op_params))[0]; | ||
| 3336 | const int32_t s1 = ((const int32_t *)(op->op_params))[1]; | ||
| 3337 | const int32_t p0 = ((const int32_t *)(op->op_params))[2]; | ||
| 3338 | const int32_t p1 = ((const int32_t *)(op->op_params))[3]; | ||
| 3339 | const int32_t d0 = ((const int32_t *)(op->op_params))[4]; | ||
| 3340 | const int32_t d1 = ((const int32_t *)(op->op_params))[5]; | ||
| 3341 | |||
| 3342 | const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1; | ||
| 3343 | |||
| 3344 | const int32_t N = op->src[1]->ne[is_2D ? 3 : 2]; | ||
| 3345 | const int32_t IC = op->src[1]->ne[is_2D ? 2 : 1]; | ||
| 3346 | const int32_t IH = is_2D ? op->src[1]->ne[1] : 1; | ||
| 3347 | const int32_t IW = op->src[1]->ne[0]; | ||
| 3348 | |||
| 3349 | const int32_t KH = is_2D ? op->src[0]->ne[1] : 1; | ||
| 3350 | const int32_t KW = op->src[0]->ne[0]; | ||
| 3351 | |||
| 3352 | const int32_t OH = is_2D ? op->ne[2] : 1; | ||
| 3353 | const int32_t OW = op->ne[1]; | ||
| 3354 | |||
| 3355 | const int32_t CHW = IC * KH * KW; | ||
| 3356 | |||
| 3357 | const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4; | ||
| 3358 | const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4; | ||
| 3359 | |||
| 3360 | ggml_metal_kargs_im2col args = { | ||
| 3361 | /*.ofs0 =*/ ofs0, | ||
| 3362 | /*.ofs1 =*/ ofs1, | ||
| 3363 | /*.IW =*/ IW, | ||
| 3364 | /*.IH =*/ IH, | ||
| 3365 | /*.CHW =*/ CHW, | ||
| 3366 | /*.s0 =*/ s0, | ||
| 3367 | /*.s1 =*/ s1, | ||
| 3368 | /*.p0 =*/ p0, | ||
| 3369 | /*.p1 =*/ p1, | ||
| 3370 | /*.d0 =*/ d0, | ||
| 3371 | /*.d1 =*/ d1, | ||
| 3372 | /*.N =*/ N, | ||
| 3373 | /*.KH =*/ KH, | ||
| 3374 | /*.KW =*/ KW, | ||
| 3375 | /*.KHW =*/ KH * KW, | ||
| 3376 | }; | ||
| 3377 | |||
| 3378 | auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); | ||
| 3379 | |||
| 3380 | GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); | ||
| 3381 | |||
| 3382 | const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N); | ||
| 3383 | |||
| 3384 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 3385 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 3386 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); | ||
| 3387 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); | ||
| 3388 | |||
| 3389 | ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW); | ||
| 3390 | |||
| 3391 | return 1; | ||
| 3392 | } | ||
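
im2col lays the input out so that every (output position, kernel position, channel) combination becomes one element of a matrix whose row length is `CHW = IC*KH*KW`, which lets a plain matmul implement the convolution. When reading the dispatch above it helps to keep the usual convolution output-size relation in mind; a short check of that arithmetic, with hypothetical shape values:

```cpp
// Sketch: the standard convolution output-size arithmetic that the im2col
// dispatch relies on. All shape values are hypothetical.
#include <cstdio>

// number of output positions along one axis
int conv_out_size(int in, int kernel, int stride, int pad, int dilation) {
    return (in + 2*pad - dilation*(kernel - 1) - 1)/stride + 1;
}

int main() {
    const int IW = 224, KW = 3, s0 = 1, p0 = 1, d0 = 1;
    const int IH = 224, KH = 3, s1 = 1, p1 = 1, d1 = 1;
    const int IC = 64;

    const int OW  = conv_out_size(IW, KW, s0, p0, d0); // 224
    const int OH  = conv_out_size(IH, KH, s1, p1, d1); // 224
    const int CHW = IC*KH*KW;                          // 576: row length of the im2col matrix

    printf("OW=%d OH=%d CHW=%d\n", OW, OH, CHW);
    return 0;
}
```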
| 3393 | |||
| 3394 | int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) { | ||
| 3395 | ggml_tensor * op = ctx->node(idx); | ||
| 3396 | |||
| 3397 | ggml_metal_library_t lib = ctx->lib; | ||
| 3398 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 3399 | |||
| 3400 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 3401 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 3402 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 3403 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 3404 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 3405 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 3406 | |||
| 3407 | GGML_ASSERT(ggml_is_contiguous(op->src[0])); | ||
| 3408 | GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); | ||
| 3409 | GGML_ASSERT(op->type == GGML_TYPE_F32); | ||
| 3410 | GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); | ||
| 3411 | |||
| 3412 | const int32_t s0 = ((const int32_t *) op->op_params)[0]; | ||
| 3413 | const int32_t s1 = ((const int32_t *) op->op_params)[1]; | ||
| 3414 | const int32_t p0 = ((const int32_t *) op->op_params)[2]; | ||
| 3415 | const int32_t p1 = ((const int32_t *) op->op_params)[3]; | ||
| 3416 | const int32_t d0 = ((const int32_t *) op->op_params)[4]; | ||
| 3417 | const int32_t d1 = ((const int32_t *) op->op_params)[5]; | ||
| 3418 | |||
| 3419 | ggml_metal_kargs_conv_2d args = { | ||
| 3420 | /*.nb00 =*/ nb00, | ||
| 3421 | /*.nb01 =*/ nb01, | ||
| 3422 | /*.nb02 =*/ nb02, | ||
| 3423 | /*.nb03 =*/ nb03, | ||
| 3424 | /*.nb10 =*/ nb10, | ||
| 3425 | /*.nb11 =*/ nb11, | ||
| 3426 | /*.nb12 =*/ nb12, | ||
| 3427 | /*.nb13 =*/ nb13, | ||
| 3428 | /*.nb0 =*/ nb0, | ||
| 3429 | /*.nb1 =*/ nb1, | ||
| 3430 | /*.nb2 =*/ nb2, | ||
| 3431 | /*.nb3 =*/ nb3, | ||
| 3432 | /*.IW =*/ ne10, | ||
| 3433 | /*.IH =*/ ne11, | ||
| 3434 | /*.KW =*/ ne00, | ||
| 3435 | /*.KH =*/ ne01, | ||
| 3436 | /*.IC =*/ ne02, | ||
| 3437 | /*.OC =*/ ne03, | ||
| 3438 | /*.OW =*/ ne0, | ||
| 3439 | /*.OH =*/ ne1, | ||
| 3440 | /*.N =*/ ne3, | ||
| 3441 | /*.s0 =*/ s0, | ||
| 3442 | /*.s1 =*/ s1, | ||
| 3443 | /*.p0 =*/ p0, | ||
| 3444 | /*.p1 =*/ p1, | ||
| 3445 | /*.d0 =*/ d0, | ||
| 3446 | /*.d1 =*/ d1, | ||
| 3447 | }; | ||
| 3448 | |||
| 3449 | auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op); | ||
| 3450 | |||
| 3451 | int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline); | ||
| 3452 | nth = std::min(nth, 256); | ||
| 3453 | nth = std::max(nth, 1); | ||
| 3454 | |||
| 3455 | const uint64_t n_out = ggml_nelements(op); | ||
| 3456 | |||
| 3457 | uint64_t tg = (n_out + nth - 1)/nth; | ||
| 3458 | tg = std::max<uint64_t>(tg, 1); | ||
| 3459 | tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max()); | ||
| 3460 | |||
| 3461 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 3462 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 3463 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 3464 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 3465 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); | ||
| 3466 | |||
| 3467 | ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1); | ||
| 3468 | |||
| 3469 | return 1; | ||
| 3470 | } | ||
| 3471 | |||
| 3472 | int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { | ||
| 3473 | ggml_tensor * op = ctx->node(idx); | ||
| 3474 | |||
| 3475 | ggml_metal_library_t lib = ctx->lib; | ||
| 3476 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 3477 | |||
| 3478 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 3479 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 3480 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 3481 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 3482 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 3483 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 3484 | |||
| 3485 | const int32_t s0 = ((const int32_t *)(op->op_params))[0]; | ||
| 3486 | |||
| 3487 | const int32_t IC = op->src[1]->ne[1]; | ||
| 3488 | const int32_t IL = op->src[1]->ne[0]; | ||
| 3489 | |||
| 3490 | const int32_t K = op->src[0]->ne[0]; | ||
| 3491 | |||
| 3492 | const int32_t OL = op->ne[0]; | ||
| 3493 | const int32_t OC = op->ne[1]; | ||
| 3494 | |||
| 3495 | ggml_metal_kargs_conv_transpose_1d args = { | ||
| 3496 | /*.IC =*/ IC, | ||
| 3497 | /*.IL =*/ IL, | ||
| 3498 | /*.K =*/ K, | ||
| 3499 | /*.s0 =*/ s0, | ||
| 3500 | /*.nb0 =*/ nb0, | ||
| 3501 | /*.nb1 =*/ nb1, | ||
| 3502 | }; | ||
| 3503 | |||
| 3504 | auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op); | ||
| 3505 | |||
| 3506 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 3507 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 3508 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 3509 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 3510 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); | ||
| 3511 | |||
| 3512 | ggml_metal_encoder_dispatch_threadgroups(enc, OL, OC, 1, 1, 1, 1); | ||
| 3513 | |||
| 3514 | return 1; | ||
| 3515 | } | ||
| 3516 | |||
| 3517 | int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) { | ||
| 3518 | ggml_tensor * op = ctx->node(idx); | ||
| 3519 | |||
| 3520 | ggml_metal_library_t lib = ctx->lib; | ||
| 3521 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 3522 | |||
| 3523 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 3524 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 3525 | GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); | ||
| 3526 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 3527 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 3528 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 3529 | |||
| 3530 | const int32_t s0 = ((const int32_t *)(op->op_params))[0]; | ||
| 3531 | |||
| 3532 | const int32_t IC = op->src[1]->ne[2]; | ||
| 3533 | const int32_t IH = op->src[1]->ne[1]; | ||
| 3534 | const int32_t IW = op->src[1]->ne[0]; | ||
| 3535 | |||
| 3536 | const int32_t KH = op->src[0]->ne[1]; | ||
| 3537 | const int32_t KW = op->src[0]->ne[0]; | ||
| 3538 | |||
| 3539 | const int32_t OW = op->ne[0]; | ||
| 3540 | const int32_t OH = op->ne[1]; | ||
| 3541 | const int32_t OC = op->ne[2]; | ||
| 3542 | |||
| 3543 | ggml_metal_kargs_conv_transpose_2d args = { | ||
| 3544 | /*.IC =*/ IC, | ||
| 3545 | /*.IH =*/ IH, | ||
| 3546 | /*.IW =*/ IW, | ||
| 3547 | /*.KH =*/ KH, | ||
| 3548 | /*.KW =*/ KW, | ||
| 3549 | /*.OC =*/ OC, | ||
| 3550 | /*.s0 =*/ s0, | ||
| 3551 | /*.nb0 =*/ nb0, | ||
| 3552 | /*.nb1 =*/ nb1, | ||
| 3553 | /*.nb2 =*/ nb2, | ||
| 3554 | }; | ||
| 3555 | |||
| 3556 | auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op); | ||
| 3557 | |||
| 3558 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 3559 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 3560 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 3561 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 3562 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); | ||
| 3563 | |||
| 3564 | // Metal requires the buffer size to be a multiple of 16 bytes | ||
| 3565 | const size_t smem = GGML_PAD(KW * KH * sizeof(float), 16); | ||
| 3566 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 3567 | |||
| 3568 | ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1); | ||
| 3569 | |||
| 3570 | return 1; | ||
| 3571 | } | ||
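The threadgroup memory size passed to the encoder above is padded up to a 16-byte multiple, per the Metal requirement noted in the comment. A small sketch of that padding, with a local `pad_to` helper standing in for the GGML_PAD macro (illustrative only):

```cpp
#include <cstddef>
#include <cstdio>

// Stand-in for GGML_PAD: round x up to the next multiple of n (n a power of two here).
static size_t pad_to(size_t x, size_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const int    KW   = 3, KH = 3;                 // hypothetical kernel size
    const size_t raw  = KW * KH * sizeof(float);   // 36 bytes
    const size_t smem = pad_to(raw, 16);           // 48 bytes, a multiple of 16
    std::printf("raw = %zu bytes, padded = %zu bytes\n", raw, smem);
    return 0;
}
```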
| 3572 | |||
| 3573 | int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { | ||
| 3574 | ggml_tensor * op = ctx->node(idx); | ||
| 3575 | |||
| 3576 | ggml_metal_library_t lib = ctx->lib; | ||
| 3577 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 3578 | |||
| 3579 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 3580 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 3581 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 3582 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 3583 | |||
| 3584 | const float sf0 = (float)ne0/op->src[0]->ne[0]; | ||
| 3585 | const float sf1 = (float)ne1/op->src[0]->ne[1]; | ||
| 3586 | const float sf2 = (float)ne2/op->src[0]->ne[2]; | ||
| 3587 | const float sf3 = (float)ne3/op->src[0]->ne[3]; | ||
| 3588 | |||
| 3589 | ggml_metal_kargs_upscale args = { | ||
| 3590 | /*.ne00 =*/ ne00, | ||
| 3591 | /*.ne01 =*/ ne01, | ||
| 3592 | /*.ne02 =*/ ne02, | ||
| 3593 | /*.ne03 =*/ ne03, | ||
| 3594 | /*.nb00 =*/ nb00, | ||
| 3595 | /*.nb01 =*/ nb01, | ||
| 3596 | /*.nb02 =*/ nb02, | ||
| 3597 | /*.nb03 =*/ nb03, | ||
| 3598 | /*.ne0 =*/ ne0, | ||
| 3599 | /*.ne1 =*/ ne1, | ||
| 3600 | /*.ne2 =*/ ne2, | ||
| 3601 | /*.ne3 =*/ ne3, | ||
| 3602 | /*.nb0 =*/ nb0, | ||
| 3603 | /*.nb1 =*/ nb1, | ||
| 3604 | /*.nb2 =*/ nb2, | ||
| 3605 | /*.nb3 =*/ nb3, | ||
| 3606 | /*.sf0 =*/ sf0, | ||
| 3607 | /*.sf1 =*/ sf1, | ||
| 3608 | /*.sf2 =*/ sf2, | ||
| 3609 | /*.sf3 =*/ sf3 | ||
| 3610 | }; | ||
| 3611 | |||
| 3612 | auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); | ||
| 3613 | |||
| 3614 | const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); | ||
| 3615 | |||
| 3616 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 3617 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 3618 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 3619 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); | ||
| 3620 | |||
| 3621 | ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); | ||
| 3622 | |||
| 3623 | return 1; | ||
| 3624 | } | ||
| 3625 | |||
| 3626 | int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { | ||
| 3627 | ggml_tensor * op = ctx->node(idx); | ||
| 3628 | |||
| 3629 | ggml_metal_library_t lib = ctx->lib; | ||
| 3630 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 3631 | |||
| 3632 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 3633 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 3634 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 3635 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 3636 | |||
| 3637 | ggml_metal_kargs_pad args = { | ||
| 3638 | /*.ne00 =*/ ne00, | ||
| 3639 | /*.ne01 =*/ ne01, | ||
| 3640 | /*.ne02 =*/ ne02, | ||
| 3641 | /*.ne03 =*/ ne03, | ||
| 3642 | /*.nb00 =*/ nb00, | ||
| 3643 | /*.nb01 =*/ nb01, | ||
| 3644 | /*.nb02 =*/ nb02, | ||
| 3645 | /*.nb03 =*/ nb03, | ||
| 3646 | /*.ne0 =*/ ne0, | ||
| 3647 | /*.ne1 =*/ ne1, | ||
| 3648 | /*.ne2 =*/ ne2, | ||
| 3649 | /*.ne3 =*/ ne3, | ||
| 3650 | /*.nb0 =*/ nb0, | ||
| 3651 | /*.nb1 =*/ nb1, | ||
| 3652 | /*.nb2 =*/ nb2, | ||
| 3653 | /*.nb3 =*/ nb3 | ||
| 3654 | }; | ||
| 3655 | |||
| 3656 | auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op); | ||
| 3657 | |||
| 3658 | const int nth = std::min(1024, ne0); | ||
| 3659 | |||
| 3660 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 3661 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 3662 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 3663 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); | ||
| 3664 | |||
| 3665 | ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); | ||
| 3666 | |||
| 3667 | return 1; | ||
| 3668 | } | ||
| 3669 | |||
| 3670 | int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) { | ||
| 3671 | ggml_tensor * op = ctx->node(idx); | ||
| 3672 | |||
| 3673 | ggml_metal_library_t lib = ctx->lib; | ||
| 3674 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 3675 | |||
| 3676 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 3677 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 3678 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 3679 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 3680 | |||
| 3681 | ggml_metal_kargs_pad_reflect_1d args = { | ||
| 3682 | /*.ne00 =*/ ne00, | ||
| 3683 | /*.ne01 =*/ ne01, | ||
| 3684 | /*.ne02 =*/ ne02, | ||
| 3685 | /*.ne03 =*/ ne03, | ||
| 3686 | /*.nb00 =*/ nb00, | ||
| 3687 | /*.nb01 =*/ nb01, | ||
| 3688 | /*.nb02 =*/ nb02, | ||
| 3689 | /*.nb03 =*/ nb03, | ||
| 3690 | /*.ne0 =*/ ne0, | ||
| 3691 | /*.ne1 =*/ ne1, | ||
| 3692 | /*.ne2 =*/ ne2, | ||
| 3693 | /*.ne3 =*/ ne3, | ||
| 3694 | /*.nb0 =*/ nb0, | ||
| 3695 | /*.nb1 =*/ nb1, | ||
| 3696 | /*.nb2 =*/ nb2, | ||
| 3697 | /*.nb3 =*/ nb3, | ||
| 3698 | /*.p0 =*/ ((const int32_t *)(op->op_params))[0], | ||
| 3699 | /*.p1 =*/ ((const int32_t *)(op->op_params))[1] | ||
| 3700 | }; | ||
| 3701 | |||
| 3702 | auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op); | ||
| 3703 | |||
| 3704 | const int nth = std::min(1024, ne0); | ||
| 3705 | |||
| 3706 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 3707 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 3708 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 3709 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); | ||
| 3710 | |||
| 3711 | ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); | ||
| 3712 | |||
| 3713 | return 1; | ||
| 3714 | } | ||
| 3715 | |||
| 3716 | int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) { | ||
| 3717 | ggml_tensor * op = ctx->node(idx); | ||
| 3718 | |||
| 3719 | ggml_metal_library_t lib = ctx->lib; | ||
| 3720 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 3721 | |||
| 3722 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 3723 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 3724 | |||
| 3725 | float start; | ||
| 3726 | float step; | ||
| 3727 | |||
| 3728 | memcpy(&start, ((const int32_t *) op->op_params) + 0, sizeof(float)); | ||
| 3729 | memcpy(&step, ((const int32_t *) op->op_params) + 2, sizeof(float)); | ||
| 3730 | |||
| 3731 | ggml_metal_kargs_arange args = { | ||
| 3732 | /*.ne0 =*/ ne0, | ||
| 3733 | /*.start =*/ start, | ||
| 3734 | /*.step =*/ step | ||
| 3735 | }; | ||
| 3736 | |||
| 3737 | const int nth = std::min(1024, ne0); | ||
| 3738 | |||
| 3739 | auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op); | ||
| 3740 | |||
| 3741 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 3742 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 3743 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1); | ||
| 3744 | |||
| 3745 | ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1); | ||
| 3746 | |||
| 3747 | return 1; | ||
| 3748 | } | ||
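`start` and `step` are read from `op_params` at float offsets 0 and 2; offset 1 is skipped because ggml appears to pack the arange parameters as (start, stop, step), and the kernel only needs start and step once `ne0` is fixed. A hedged sketch of reading such a packed parameter block (the (start, stop, step) layout is an assumption here, not an authoritative statement about ggml internals):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    // Hypothetical op_params block, assumed packed as (start, stop, step) floats.
    int32_t op_params[4] = {0};
    const float start_in = 0.0f, stop_in = 10.0f, step_in = 2.5f;
    std::memcpy(op_params + 0, &start_in, sizeof(float));
    std::memcpy(op_params + 1, &stop_in,  sizeof(float));
    std::memcpy(op_params + 2, &step_in,  sizeof(float));

    // Mirror of the reads above: only start and step are needed by the kernel args.
    float start, step;
    std::memcpy(&start, op_params + 0, sizeof(float));
    std::memcpy(&step,  op_params + 2, sizeof(float));

    std::printf("start = %.2f, step = %.2f\n", start, step);
    return 0;
}
```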
| 3749 | |||
| 3750 | int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) { | ||
| 3751 | ggml_tensor * op = ctx->node(idx); | ||
| 3752 | |||
| 3753 | ggml_metal_library_t lib = ctx->lib; | ||
| 3754 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 3755 | |||
| 3756 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 3757 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 3758 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 3759 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 3760 | |||
| 3761 | const int dim = op->op_params[0]; | ||
| 3762 | const int max_period = op->op_params[1]; | ||
| 3763 | |||
| 3764 | ggml_metal_kargs_timestep_embedding args = { | ||
| 3765 | /*.nb1 =*/ nb1, | ||
| 3766 | /*.dim =*/ dim, | ||
| 3767 | /*.max_period =*/ max_period, | ||
| 3768 | }; | ||
| 3769 | |||
| 3770 | auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op); | ||
| 3771 | |||
| 3772 | const int nth = std::max(1, std::min(1024, dim/2)); | ||
| 3773 | |||
| 3774 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 3775 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 3776 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 3777 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); | ||
| 3778 | |||
| 3779 | ggml_metal_encoder_dispatch_threadgroups(enc, ne00, 1, 1, nth, 1, 1); | ||
| 3780 | |||
| 3781 | return 1; | ||
| 3782 | } | ||
| 3783 | |||
| 3784 | int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { | ||
| 3785 | ggml_tensor * op = ctx->node(idx); | ||
| 3786 | |||
| 3787 | ggml_metal_library_t lib = ctx->lib; | ||
| 3788 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 3789 | |||
| 3790 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 3791 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 3792 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 3793 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 3794 | |||
| 3795 | ggml_metal_kargs_argmax args = { | ||
| 3796 | /*.ne00 = */ ne00, | ||
| 3797 | /*.nb01 = */ nb01, | ||
| 3798 | }; | ||
| 3799 | |||
| 3800 | auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op); | ||
| 3801 | |||
| 3802 | const int64_t nrows = ggml_nrows(op->src[0]); | ||
| 3803 | |||
| 3804 | int nth = 32; // SIMD width | ||
| 3805 | while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { | ||
| 3806 | nth *= 2; | ||
| 3807 | } | ||
| 3808 | |||
| 3809 | const size_t smem = pipeline.smem; | ||
| 3810 | |||
| 3811 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 3812 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 3813 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 3814 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); | ||
| 3815 | |||
| 3816 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 3817 | |||
| 3818 | ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); | ||
| 3819 | |||
| 3820 | return 1; | ||
| 3821 | } | ||
| 3822 | |||
| 3823 | int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { | ||
| 3824 | ggml_tensor * op = ctx->node(idx); | ||
| 3825 | |||
| 3826 | ggml_metal_library_t lib = ctx->lib; | ||
| 3827 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 3828 | |||
| 3829 | GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); | ||
| 3830 | |||
| 3831 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 3832 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 3833 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 3834 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 3835 | |||
| 3836 | auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op); | ||
| 3837 | |||
| 3838 | // bitonic sort requires the number of elements to be a power of 2 | ||
| 3839 | int nth = 1; | ||
| 3840 | while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { | ||
| 3841 | nth *= 2; | ||
| 3842 | } | ||
| 3843 | |||
| 3844 | const int npr = (ne00 + nth - 1)/nth; | ||
| 3845 | |||
| 3846 | // Metal kernels require the buffer size to be a multiple of 16 bytes | ||
| 3847 | // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength | ||
| 3848 | const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16); | ||
| 3849 | |||
| 3850 | ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); | ||
| 3851 | ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); | ||
| 3852 | |||
| 3853 | ggml_metal_buffer_id bid_tmp = bid_dst; | ||
| 3854 | bid_tmp.offs += ggml_nbytes(op); | ||
| 3855 | |||
| 3856 | if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) { | ||
| 3857 | std::swap(bid_dst, bid_tmp); | ||
| 3858 | } | ||
| 3859 | |||
| 3860 | ggml_metal_kargs_argsort args = { | ||
| 3861 | /*.ne00 =*/ ne00, | ||
| 3862 | /*.ne01 =*/ ne01, | ||
| 3863 | /*.ne02 =*/ ne02, | ||
| 3864 | /*.ne03 =*/ ne03, | ||
| 3865 | /*.nb00 =*/ nb00, | ||
| 3866 | /*.nb01 =*/ nb01, | ||
| 3867 | /*.nb02 =*/ nb02, | ||
| 3868 | /*.nb03 =*/ nb03, | ||
| 3869 | /*.ne0 =*/ ne0, | ||
| 3870 | /*.ne1 =*/ ne1, | ||
| 3871 | /*.ne2 =*/ ne2, | ||
| 3872 | /*.ne3 =*/ ne3, | ||
| 3873 | /*.top_k =*/ nth, | ||
| 3874 | }; | ||
| 3875 | |||
| 3876 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 3877 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 3878 | ggml_metal_encoder_set_buffer (enc, bid_src0, 1); | ||
| 3879 | ggml_metal_encoder_set_buffer (enc, bid_dst, 2); | ||
| 3880 | |||
| 3881 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 3882 | |||
| 3883 | ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); | ||
| 3884 | |||
| 3885 | auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op); | ||
| 3886 | |||
| 3887 | int len = nth; | ||
| 3888 | |||
| 3889 | while (len < ne00) { | ||
| 3890 | ggml_metal_op_concurrency_reset(ctx); | ||
| 3891 | |||
| 3892 | ggml_metal_kargs_argsort_merge args_merge = { | ||
| 3893 | /*.ne00 =*/ ne00, | ||
| 3894 | /*.ne01 =*/ ne01, | ||
| 3895 | /*.ne02 =*/ ne02, | ||
| 3896 | /*.ne03 =*/ ne03, | ||
| 3897 | /*.nb00 =*/ nb00, | ||
| 3898 | /*.nb01 =*/ nb01, | ||
| 3899 | /*.nb02 =*/ nb02, | ||
| 3900 | /*.nb03 =*/ nb03, | ||
| 3901 | /*.ne0 =*/ ne0, | ||
| 3902 | /*.ne1 =*/ ne1, | ||
| 3903 | /*.ne2 =*/ ne2, | ||
| 3904 | /*.ne3 =*/ ne3, | ||
| 3905 | /*.top_k =*/ ne00, | ||
| 3906 | /*.len =*/ len, | ||
| 3907 | }; | ||
| 3908 | |||
| 3909 | // merges per row | ||
| 3910 | const int nm = (ne00 + 2*len - 1) / (2*len); | ||
| 3911 | |||
| 3912 | const int nth = std::min(512, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge)); | ||
| 3913 | |||
| 3914 | ggml_metal_encoder_set_pipeline(enc, pipeline_merge); | ||
| 3915 | ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0); | ||
| 3916 | ggml_metal_encoder_set_buffer (enc, bid_src0, 1); | ||
| 3917 | ggml_metal_encoder_set_buffer (enc, bid_dst, 2); | ||
| 3918 | ggml_metal_encoder_set_buffer (enc, bid_tmp, 3); | ||
| 3919 | |||
| 3920 | ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1); | ||
| 3921 | |||
| 3922 | std::swap(bid_dst, bid_tmp); | ||
| 3923 | |||
| 3924 | len <<= 1; | ||
| 3925 | } | ||
| 3926 | |||
| 3927 | return 1; | ||
| 3928 | } | ||
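The per-block sort writes into `bid_dst`, and each merge pass then ping-pongs between `bid_dst` and the temporary region placed right after the destination tensor. The merge kernel presumably reads from the buffer bound at index 2 and writes to index 3, so pre-swapping the two when `ceil(log2(npr))` is odd makes the final merge land in the real destination. A minimal sketch of that parity argument (buffer IDs replaced by plain labels, purely illustrative):

```cpp
#include <cmath>
#include <cstdio>
#include <string>
#include <utility>

int main() {
    const int ne00 = 4096;  // hypothetical row length
    const int nth  = 1024;  // hypothetical per-block sort width
    const int npr  = (ne00 + nth - 1) / nth;                      // blocks per row
    const int passes = (int) std::ceil(std::log2((double) npr));  // number of merge passes

    std::string dst = "dst", tmp = "tmp";
    if (passes % 2 == 1) {
        std::swap(dst, tmp);   // pre-swap so the last pass ends up in the real "dst"
    }

    for (int len = nth, p = 0; len < ne00; len <<= 1, ++p) {
        std::printf("merge pass %d writes into %s\n", p, tmp.c_str());
        std::swap(dst, tmp);   // each pass flips the read/write roles
    }
    std::printf("final result is in %s\n", dst.c_str());
    return 0;
}
```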
| 3929 | |||
| 3930 | int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { | ||
| 3931 | ggml_tensor * op = ctx->node(idx); | ||
| 3932 | |||
| 3933 | ggml_metal_library_t lib = ctx->lib; | ||
| 3934 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 3935 | |||
| 3936 | GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); | ||
| 3937 | |||
| 3938 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 3939 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 3940 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 3941 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 3942 | |||
| 3943 | auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op); | ||
| 3944 | |||
| 3945 | // bitonic sort requires the number of elements to be a power of 2 | ||
| 3946 | int nth = 1; | ||
| 3947 | while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { | ||
| 3948 | nth *= 2; | ||
| 3949 | } | ||
| 3950 | |||
| 3951 | // blocks per row | ||
| 3952 | const int npr = (ne00 + nth - 1)/nth; | ||
| 3953 | |||
| 3954 | const size_t smem = GGML_PAD(nth*sizeof(int32_t), 16); | ||
| 3955 | |||
| 3956 | ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); | ||
| 3957 | ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); | ||
| 3958 | |||
| 3959 | ggml_metal_buffer_id bid_tmp = bid_dst; | ||
| 3960 | bid_tmp.offs += sizeof(int32_t)*ggml_nelements(op->src[0]); | ||
| 3961 | |||
| 3962 | if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) { | ||
| 3963 | std::swap(bid_dst, bid_tmp); | ||
| 3964 | } | ||
| 3965 | |||
| 3966 | const int top_k = ne0; | ||
| 3967 | |||
| 3968 | ggml_metal_kargs_argsort args = { | ||
| 3969 | /*.ne00 =*/ ne00, | ||
| 3970 | /*.ne01 =*/ ne01, | ||
| 3971 | /*.ne02 =*/ ne02, | ||
| 3972 | /*.ne03 =*/ ne03, | ||
| 3973 | /*.nb00 =*/ nb00, | ||
| 3974 | /*.nb01 =*/ nb01, | ||
| 3975 | /*.nb02 =*/ nb02, | ||
| 3976 | /*.nb03 =*/ nb03, | ||
| 3977 | /*.ne0 =*/ ne0, | ||
| 3978 | /*.ne1 =*/ ne1, | ||
| 3979 | /*.ne2 =*/ ne2, | ||
| 3980 | /*.ne3 =*/ ne3, | ||
| 3981 | /*.top_k =*/ std::min(nth, top_k), // for each block, keep just the top_k indices | ||
| 3982 | }; | ||
| 3983 | |||
| 3984 | if (npr > 1) { | ||
| 3985 | args.ne0 = (npr - 1)*args.top_k + std::min(ne00 - (npr - 1)*nth, args.top_k); | ||
| 3986 | } | ||
| 3987 | |||
| 3988 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 3989 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 3990 | ggml_metal_encoder_set_buffer (enc, bid_src0, 1); | ||
| 3991 | ggml_metal_encoder_set_buffer (enc, bid_dst, 2); | ||
| 3992 | |||
| 3993 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 3994 | |||
| 3995 | ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1); | ||
| 3996 | |||
| 3997 | auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op); | ||
| 3998 | |||
| 3999 | int len = args.top_k; | ||
| 4000 | |||
| 4001 | while (len < args.ne0) { | ||
| 4002 | ggml_metal_op_concurrency_reset(ctx); | ||
| 4003 | |||
| 4004 | // merges per row | ||
| 4005 | const int nm = (args.ne0 + 2*len - 1) / (2*len); | ||
| 4006 | |||
| 4007 | const int nth = std::min(512, std::min(len, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_merge))); | ||
| 4008 | |||
| 4009 | ggml_metal_kargs_argsort_merge args_merge = { | ||
| 4010 | /*.ne00 =*/ ne00, | ||
| 4011 | /*.ne01 =*/ ne01, | ||
| 4012 | /*.ne02 =*/ ne02, | ||
| 4013 | /*.ne03 =*/ ne03, | ||
| 4014 | /*.nb00 =*/ nb00, | ||
| 4015 | /*.nb01 =*/ nb01, | ||
| 4016 | /*.nb02 =*/ nb02, | ||
| 4017 | /*.nb03 =*/ nb03, | ||
| 4018 | /*.ne0 =*/ args.ne0, | ||
| 4019 | /*.ne1 =*/ ne1, | ||
| 4020 | /*.ne2 =*/ ne2, | ||
| 4021 | /*.ne3 =*/ ne3, | ||
| 4022 | /*.top_k =*/ nm == 1 ? top_k : args.ne0, // the final merge outputs top_k elements | ||
| 4023 | /*.len =*/ len, | ||
| 4024 | }; | ||
| 4025 | |||
| 4026 | ggml_metal_encoder_set_pipeline(enc, pipeline_merge); | ||
| 4027 | ggml_metal_encoder_set_bytes (enc, &args_merge, sizeof(args_merge), 0); | ||
| 4028 | ggml_metal_encoder_set_buffer (enc, bid_src0, 1); | ||
| 4029 | ggml_metal_encoder_set_buffer (enc, bid_dst, 2); | ||
| 4030 | ggml_metal_encoder_set_buffer (enc, bid_tmp, 3); | ||
| 4031 | |||
| 4032 | ggml_metal_encoder_dispatch_threadgroups(enc, nm*ne01, ne02, ne03, nth, 1, 1); | ||
| 4033 | |||
| 4034 | std::swap(bid_dst, bid_tmp); | ||
| 4035 | |||
| 4036 | len <<= 1; | ||
| 4037 | } | ||
| 4038 | |||
| 4039 | return 1; | ||
| 4040 | } | ||
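When a row is split across `npr` blocks, each block keeps only its local top-k, so the effective row length for the merge phase shrinks to `(npr - 1)*top_k + min(remainder, top_k)` before merging proceeds. A small worked sketch of that bookkeeping (all numbers are made up):

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    const int ne00  = 5000; // hypothetical row length
    const int nth   = 1024; // per-block sort width
    const int top_k = 32;   // requested top-k (op->ne[0])

    const int npr        = (ne00 + nth - 1) / nth;   // 5 blocks per row
    const int keep       = std::min(nth, top_k);     // indices kept per block
    const int last_block = ne00 - (npr - 1) * nth;   // elements in the last block (904)

    int merged_len = ne00;
    if (npr > 1) {
        merged_len = (npr - 1) * keep + std::min(last_block, keep);  // 4*32 + 32 = 160
    }
    std::printf("blocks = %d, kept per block = %d, merged row length = %d\n",
                npr, keep, merged_len);
    return 0;
}
```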
| 4041 | |||
| 4042 | int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { | ||
| 4043 | ggml_tensor * op = ctx->node(idx); | ||
| 4044 | |||
| 4045 | ggml_metal_library_t lib = ctx->lib; | ||
| 4046 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 4047 | |||
| 4048 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 4049 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 4050 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 4051 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 4052 | |||
| 4053 | ggml_metal_kargs_tri args = { | ||
| 4054 | /*.ne00 =*/ ne00, | ||
| 4055 | /*.ne01 =*/ ne01, | ||
| 4056 | /*.ne02 =*/ ne02, | ||
| 4057 | /*.ne03 =*/ ne03, | ||
| 4058 | /*.nb00 =*/ nb00, | ||
| 4059 | /*.nb01 =*/ nb01, | ||
| 4060 | /*.nb02 =*/ nb02, | ||
| 4061 | /*.nb03 =*/ nb03, | ||
| 4062 | /*.ne0 =*/ ne0, | ||
| 4063 | /*.ne1 =*/ ne1, | ||
| 4064 | /*.ne2 =*/ ne2, | ||
| 4065 | /*.ne3 =*/ ne3, | ||
| 4066 | /*.nb0 =*/ nb0, | ||
| 4067 | /*.nb1 =*/ nb1, | ||
| 4068 | /*.nb2 =*/ nb2, | ||
| 4069 | /*.nb3 =*/ nb3, | ||
| 4070 | }; | ||
| 4071 | |||
| 4072 | auto pipeline = ggml_metal_library_get_pipeline_tri(lib, op); | ||
| 4073 | |||
| 4074 | int nth = 32; // SIMD width | ||
| 4075 | |||
| 4076 | while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { | ||
| 4077 | nth *= 2; | ||
| 4078 | } | ||
| 4079 | |||
| 4080 | nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); | ||
| 4081 | nth = std::min(nth, ne00); | ||
| 4082 | |||
| 4083 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 4084 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); | ||
| 4085 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 4086 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); | ||
| 4087 | |||
| 4088 | ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); | ||
| 4089 | |||
| 4090 | return 1; | ||
| 4091 | } | ||
| 4092 | |||
| 4093 | int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) { | ||
| 4094 | ggml_tensor * op = ctx->node(idx); | ||
| 4095 | |||
| 4096 | ggml_metal_library_t lib = ctx->lib; | ||
| 4097 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 4098 | |||
| 4099 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 4100 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 4101 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 4102 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 4103 | |||
| 4104 | auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op); | ||
| 4105 | |||
| 4106 | const int64_t np = ggml_nelements(op->src[0]); | ||
| 4107 | ggml_metal_kargs_opt_step_adamw args = { | ||
| 4108 | /*.np =*/ np, | ||
| 4109 | }; | ||
| 4110 | |||
| 4111 | int ida = 0; | ||
| 4112 | |||
| 4113 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 4114 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); | ||
| 4115 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); | ||
| 4116 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); | ||
| 4117 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); | ||
| 4118 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); | ||
| 4119 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); | ||
| 4120 | |||
| 4121 | const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); | ||
| 4122 | const int64_t n = (np + nth - 1) / nth; | ||
| 4123 | |||
| 4124 | ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1); | ||
| 4125 | |||
| 4126 | return 1; | ||
| 4127 | } | ||
| 4128 | |||
| 4129 | int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) { | ||
| 4130 | ggml_tensor * op = ctx->node(idx); | ||
| 4131 | |||
| 4132 | ggml_metal_library_t lib = ctx->lib; | ||
| 4133 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 4134 | |||
| 4135 | GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||
| 4136 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 4137 | GGML_TENSOR_LOCALS( int32_t, ne, op, ne); | ||
| 4138 | GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); | ||
| 4139 | |||
| 4140 | auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op); | ||
| 4141 | |||
| 4142 | const int64_t np = ggml_nelements(op->src[0]); | ||
| 4143 | ggml_metal_kargs_opt_step_sgd args = { | ||
| 4144 | /*.np =*/ np, | ||
| 4145 | }; | ||
| 4146 | |||
| 4147 | int ida = 0; | ||
| 4148 | |||
| 4149 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 4150 | ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); | ||
| 4151 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); | ||
| 4152 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); | ||
| 4153 | ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); | ||
| 4154 | |||
| 4155 | const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); | ||
| 4156 | const int64_t n = (np + nth - 1) / nth; | ||
| 4157 | |||
| 4158 | ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1); | ||
| 4159 | |||
| 4160 | return 1; | ||
| 4161 | } | ||
| 4162 | |||
| 4163 | int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) { | ||
| 4164 | ggml_tensor * op = ctx->node(idx); | ||
| 4165 | |||
| 4166 | ggml_metal_library_t lib = ctx->lib; | ||
| 4167 | ggml_metal_encoder_t enc = ctx->enc; | ||
| 4168 | |||
| 4169 | GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne); | ||
| 4170 | GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||
| 4171 | GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); | ||
| 4172 | |||
| 4173 | { | ||
| 4174 | ggml_metal_kargs_memset args = { /*.val =*/ 0 }; | ||
| 4175 | |||
| 4176 | auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op); | ||
| 4177 | |||
| 4178 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 4179 | ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); | ||
| 4180 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1); | ||
| 4181 | |||
| 4182 | ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1); | ||
| 4183 | } | ||
| 4184 | |||
| 4185 | ggml_metal_op_concurrency_reset(ctx); | ||
| 4186 | |||
| 4187 | { | ||
| 4188 | ggml_metal_kargs_count_equal args = { | ||
| 4189 | /*.ne00 =*/ ne00, | ||
| 4190 | /*.ne01 =*/ ne01, | ||
| 4191 | /*.ne02 =*/ ne02, | ||
| 4192 | /*.ne03 =*/ ne03, | ||
| 4193 | /*.nb00 =*/ nb00, | ||
| 4194 | /*.nb01 =*/ nb01, | ||
| 4195 | /*.nb02 =*/ nb02, | ||
| 4196 | /*.nb03 =*/ nb03, | ||
| 4197 | /*.nb10 =*/ nb10, | ||
| 4198 | /*.nb11 =*/ nb11, | ||
| 4199 | /*.nb12 =*/ nb12, | ||
| 4200 | /*.nb13 =*/ nb13, | ||
| 4201 | }; | ||
| 4202 | |||
| 4203 | auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op); | ||
| 4204 | |||
| 4205 | const size_t smem = pipeline.smem; | ||
| 4206 | |||
| 4207 | const int nth = 32*pipeline.nsg; | ||
| 4208 | |||
| 4209 | GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); | ||
| 4210 | |||
| 4211 | ggml_metal_encoder_set_pipeline(enc, pipeline); | ||
| 4212 | ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); | ||
| 4213 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); | ||
| 4214 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); | ||
| 4215 | ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); | ||
| 4216 | |||
| 4217 | ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||
| 4218 | ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); | ||
| 4219 | } | ||
| 4220 | |||
| 4221 | return 1; | ||
| 4222 | } | ||
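This op is encoded as two dispatches: a memset that zeroes the output counter, then the counting kernel that accumulates into it, with `ggml_metal_op_concurrency_reset` in between so the second dispatch is ordered after the first (both touch the same output buffer). A trivial host-side analogy of the same reset-then-accumulate pattern (no Metal involved, just the ordering idea):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<int32_t> a = {1, 2, 3, 4, 5, 6};
    const std::vector<int32_t> b = {1, 0, 3, 0, 5, 0};

    int64_t count = 0;            // pass 1: reset the accumulator (the memset dispatch)
    for (size_t i = 0; i < a.size(); ++i) {
        count += (a[i] == b[i]);  // pass 2: accumulate matches (the count_equal dispatch)
    }
    std::printf("equal elements: %lld\n", (long long) count);
    return 0;
}
```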
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.h b/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.h new file mode 100644 index 0000000..29456d7 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal-ops.h | |||
| @@ -0,0 +1,93 @@ | |||
| 1 | #pragma once | ||
| 2 | |||
| 3 | #include "ggml-metal-device.h" | ||
| 4 | |||
| 5 | #ifdef __cplusplus | ||
| 6 | extern "C" { | ||
| 7 | #endif | ||
| 8 | |||
| 9 | typedef struct ggml_metal_op * ggml_metal_op_t; | ||
| 10 | |||
| 11 | ggml_metal_op_t ggml_metal_op_init( | ||
| 12 | ggml_metal_device_t dev, | ||
| 13 | ggml_metal_cmd_buf_t cmd_buf, | ||
| 14 | struct ggml_cgraph * gf, | ||
| 15 | int idx_start, | ||
| 16 | int idx_end, | ||
| 17 | bool use_fusion, | ||
| 18 | bool use_concurrency, | ||
| 19 | bool use_capture, | ||
| 20 | int debug_graph, | ||
| 21 | int debug_fusion); | ||
| 22 | |||
| 23 | void ggml_metal_op_free(ggml_metal_op_t ctx); | ||
| 24 | |||
| 25 | int ggml_metal_op_n_nodes(ggml_metal_op_t ctx); | ||
| 26 | |||
| 27 | int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx); | ||
| 28 | |||
| 29 | // | ||
| 30 | // available ops: | ||
| 31 | // | ||
| 32 | |||
| 33 | // tokens per expert | ||
| 34 | size_t ggml_metal_op_mul_mat_id_extra_tpe(const struct ggml_tensor * op); | ||
| 35 | |||
| 36 | // id map [n_tokens, n_expert] | ||
| 37 | size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op); | ||
| 38 | |||
| 39 | // return true if we should use the FA vector kernel for this op | ||
| 40 | bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op); | ||
| 41 | |||
| 42 | size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op); | ||
| 43 | size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op); | ||
| 44 | size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op); | ||
| 45 | |||
| 46 | int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx); | ||
| 47 | int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx); | ||
| 48 | int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx); | ||
| 49 | int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx); | ||
| 50 | int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx); | ||
| 51 | int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx); | ||
| 52 | int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx); | ||
| 53 | int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx); | ||
| 54 | int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx); | ||
| 55 | int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx); | ||
| 56 | int ggml_metal_op_diag (ggml_metal_op_t ctx, int idx); | ||
| 57 | int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); | ||
| 58 | int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); | ||
| 59 | int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); | ||
| 60 | int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); | ||
| 61 | int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx); | ||
| 62 | int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); | ||
| 63 | int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx); | ||
| 64 | int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx); | ||
| 65 | int ggml_metal_op_mul_mat (ggml_metal_op_t ctx, int idx); | ||
| 66 | int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx); | ||
| 67 | int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx); | ||
| 68 | int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx); | ||
| 69 | int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx); | ||
| 70 | int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx); | ||
| 71 | int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx); | ||
| 72 | int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx); | ||
| 73 | int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx); | ||
| 74 | int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx); | ||
| 75 | int ggml_metal_op_conv_2d (ggml_metal_op_t ctx, int idx); | ||
| 76 | int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx); | ||
| 77 | int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx); | ||
| 78 | int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx); | ||
| 79 | int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx); | ||
| 80 | int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx); | ||
| 81 | int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx); | ||
| 82 | int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); | ||
| 83 | int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); | ||
| 84 | int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); | ||
| 85 | int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx); | ||
| 86 | int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx); | ||
| 87 | int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); | ||
| 88 | int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); | ||
| 89 | int ggml_metal_op_count_equal (ggml_metal_op_t ctx, int idx); | ||
| 90 | |||
| 91 | #ifdef __cplusplus | ||
| 92 | } | ||
| 93 | #endif | ||
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal.cpp b/llama.cpp/ggml/src/ggml-metal/ggml-metal.cpp new file mode 100644 index 0000000..1c70536 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal.cpp | |||
| @@ -0,0 +1,937 @@ | |||
| 1 | #include "ggml-metal.h" | ||
| 2 | |||
| 3 | #include "ggml-impl.h" | ||
| 4 | #include "ggml-backend-impl.h" | ||
| 5 | |||
| 6 | #include "ggml-metal-device.h" | ||
| 7 | #include "ggml-metal-context.h" | ||
| 8 | #include "ggml-metal-ops.h" | ||
| 9 | |||
| 10 | #include <mutex> | ||
| 11 | #include <string> | ||
| 12 | |||
| 13 | #define GGML_METAL_NAME "MTL" | ||
| 14 | #define GGML_METAL_MAX_DEVICES 16 | ||
| 15 | |||
| 16 | // number of Metal devices | ||
| 17 | // note: can be overridden with GGML_METAL_DEVICES env to simulate virtual devices | ||
| 18 | static int g_devices = 1; | ||
| 19 | |||
| 20 | //////////////////////////////////////////////////////////////////////////////// | ||
| 21 | // backend interface | ||
| 22 | //////////////////////////////////////////////////////////////////////////////// | ||
| 23 | |||
| 24 | // shared buffer | ||
| 25 | |||
| 26 | static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t buffer) { | ||
| 27 | ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; | ||
| 28 | |||
| 29 | GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); | ||
| 30 | |||
| 31 | ggml_metal_buffer_free(ctx); | ||
| 32 | } | ||
| 33 | |||
| 34 | static void * ggml_backend_metal_buffer_shared_get_base(ggml_backend_buffer_t buffer) { | ||
| 35 | ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; | ||
| 36 | |||
| 37 | GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); | ||
| 38 | |||
| 39 | return ggml_metal_buffer_get_base(ctx); | ||
| 40 | } | ||
| 41 | |||
| 42 | static void ggml_backend_metal_buffer_shared_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { | ||
| 43 | ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; | ||
| 44 | |||
| 45 | GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); | ||
| 46 | |||
| 47 | ggml_metal_buffer_memset_tensor(ctx, tensor, value, offset, size); | ||
| 48 | } | ||
| 49 | |||
| 50 | static void ggml_backend_metal_buffer_shared_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { | ||
| 51 | ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; | ||
| 52 | |||
| 53 | GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); | ||
| 54 | |||
| 55 | ggml_metal_buffer_set_tensor(ctx, tensor, data, offset, size); | ||
| 56 | } | ||
| 57 | |||
| 58 | static void ggml_backend_metal_buffer_shared_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { | ||
| 59 | ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; | ||
| 60 | |||
| 61 | GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); | ||
| 62 | |||
| 63 | ggml_metal_buffer_get_tensor(ctx, tensor, data, offset, size); | ||
| 64 | } | ||
| 65 | |||
| 66 | static bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { | ||
| 67 | ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; | ||
| 68 | |||
| 69 | GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); | ||
| 70 | |||
| 71 | GGML_UNUSED(buffer); | ||
| 72 | GGML_UNUSED(src); | ||
| 73 | GGML_UNUSED(dst); | ||
| 74 | |||
| 75 | return false; | ||
| 76 | } | ||
| 77 | |||
| 78 | static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) { | ||
| 79 | ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; | ||
| 80 | |||
| 81 | GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); | ||
| 82 | |||
| 83 | ggml_metal_buffer_clear(ctx, value); | ||
| 84 | } | ||
| 85 | |||
| 86 | static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = { | ||
| 87 | /* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer, | ||
| 88 | /* .get_base = */ ggml_backend_metal_buffer_shared_get_base, | ||
| 89 | /* .init_tensor = */ NULL, | ||
| 90 | /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor, | ||
| 91 | /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor, | ||
| 92 | /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor, | ||
| 93 | /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor, | ||
| 94 | /* .clear = */ ggml_backend_metal_buffer_shared_clear, | ||
| 95 | /* .reset = */ NULL, | ||
| 96 | }; | ||
| 97 | |||
| 98 | // private buffer | ||
| 99 | |||
| 100 | static void ggml_backend_metal_buffer_private_free_buffer(ggml_backend_buffer_t buffer) { | ||
| 101 | ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; | ||
| 102 | |||
| 103 | GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); | ||
| 104 | |||
| 105 | ggml_metal_buffer_free(ctx); | ||
| 106 | } | ||
| 107 | |||
| 108 | static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) { | ||
| 109 | ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; | ||
| 110 | |||
| 111 | GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); | ||
| 112 | |||
| 113 | return ggml_metal_buffer_get_base(ctx); | ||
| 114 | } | ||
| 115 | |||
| 116 | static void ggml_backend_metal_buffer_private_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { | ||
| 117 | ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; | ||
| 118 | |||
| 119 | GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); | ||
| 120 | |||
| 121 | ggml_metal_buffer_memset_tensor(ctx, tensor, value, offset, size); | ||
| 122 | } | ||
| 123 | |||
| 124 | static void ggml_backend_metal_buffer_private_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { | ||
| 125 | ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; | ||
| 126 | |||
| 127 | GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); | ||
| 128 | |||
| 129 | ggml_metal_buffer_set_tensor(ctx, tensor, data, offset, size); | ||
| 130 | } | ||
| 131 | |||
| 132 | static void ggml_backend_metal_buffer_private_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { | ||
| 133 | ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; | ||
| 134 | |||
| 135 | GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); | ||
| 136 | |||
| 137 | ggml_metal_buffer_get_tensor(ctx, tensor, data, offset, size); | ||
| 138 | } | ||
| 139 | |||
| 140 | static bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { | ||
| 141 | ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; | ||
| 142 | |||
| 143 | GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); | ||
| 144 | |||
| 145 | GGML_UNUSED(buffer); | ||
| 146 | GGML_UNUSED(src); | ||
| 147 | GGML_UNUSED(dst); | ||
| 148 | |||
| 149 | return false; | ||
| 150 | } | ||
| 151 | |||
| 152 | static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) { | ||
| 153 | ggml_metal_buffer_t ctx = (ggml_metal_buffer_t)buffer->context; | ||
| 154 | |||
| 155 | GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); | ||
| 156 | |||
| 157 | ggml_metal_buffer_clear(ctx, value); | ||
| 158 | } | ||
| 159 | |||
| 160 | static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = { | ||
| 161 | /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, | ||
| 162 | /* .get_base = */ ggml_backend_metal_buffer_private_get_base, | ||
| 163 | /* .init_tensor = */ NULL, | ||
| 164 | /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, | ||
| 165 | /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, | ||
| 166 | /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, | ||
| 167 | /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, | ||
| 168 | /* .clear = */ ggml_backend_metal_buffer_private_clear, | ||
| 169 | /* .reset = */ NULL, | ||
| 170 | }; | ||
| 171 | |||
| 172 | static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) { | ||
| 173 | return buffer->iface.free_buffer == ggml_backend_metal_buffer_shared_free_buffer || | ||
| 174 | buffer->iface.free_buffer == ggml_backend_metal_buffer_private_free_buffer; | ||
| 175 | } | ||
| 176 | |||
| 177 | // | ||
| 178 | // buffer types | ||
| 179 | // | ||
| 180 | |||
| 181 | struct ggml_backend_metal_buffer_type { | ||
| 182 | int device; | ||
| 183 | std::string name; | ||
| 184 | }; | ||
| 185 | |||
| 186 | struct ggml_backend_metal_buffer_type_deleter { | ||
| 187 | void operator()(ggml_backend_metal_buffer_type * ctx) const { | ||
| 188 | delete ctx; | ||
| 189 | } | ||
| 190 | }; | ||
| 191 | |||
| 192 | typedef std::unique_ptr<ggml_backend_metal_buffer_type, ggml_backend_metal_buffer_type_deleter> ggml_backend_metal_buffer_type_ptr; | ||
| 193 | |||
| 194 | // common method for allocating shared or private Metal buffers | ||
| 195 | static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) { | ||
| 196 | ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context; | ||
| 197 | ggml_metal_buffer_t res = ggml_metal_buffer_init(ctx_dev, size, shared); | ||
| 198 | |||
| 199 | ggml_backend_buffer_i buf_i = ggml_metal_buffer_is_shared(res) | ||
| 200 | ? ggml_backend_metal_buffer_shared_i | ||
| 201 | : ggml_backend_metal_buffer_private_i; | ||
| 202 | |||
| 203 | return ggml_backend_buffer_init(buft, buf_i, res, size); | ||
| 204 | } | ||
| 205 | |||
| 206 | static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { | ||
| 207 | size_t res = ggml_nbytes(tensor); | ||
| 208 | |||
| 209 | // some operations require additional memory for fleeting data: | ||
| 210 | switch (tensor->op) { | ||
| 211 | case GGML_OP_MUL_MAT_ID: | ||
| 212 | { | ||
| 213 | res += ggml_metal_op_mul_mat_id_extra_tpe(tensor); | ||
| 214 | res += ggml_metal_op_mul_mat_id_extra_ids(tensor); | ||
| 215 | } break; | ||
| 216 | case GGML_OP_FLASH_ATTN_EXT: | ||
| 217 | { | ||
| 218 | res += ggml_metal_op_flash_attn_ext_extra_pad(tensor); | ||
| 219 | res += ggml_metal_op_flash_attn_ext_extra_blk(tensor); | ||
| 220 | res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor); | ||
| 221 | } break; | ||
| 222 | case GGML_OP_CUMSUM: | ||
| 223 | case GGML_OP_ARGSORT: | ||
| 224 | { | ||
| 225 | res *= 2; | ||
| 226 | } break; | ||
| 227 | case GGML_OP_TOP_K: | ||
| 228 | { | ||
| 229 | res = 2*sizeof(int32_t)*ggml_nelements(tensor->src[0]); | ||
| 230 | } break; | ||
| 231 | default: | ||
| 232 | break; | ||
| 233 | } | ||
| 234 | |||
| 235 | return res; | ||
| 236 | |||
| 237 | GGML_UNUSED(buft); | ||
| 238 | } | ||
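The ×2 for CUMSUM/ARGSORT and the `2*sizeof(int32_t)*ggml_nelements(src0)` for TOP_K line up with the scratch regions used earlier in ggml-metal-ops.cpp: the argsort merge passes ping-pong between the destination tensor and a same-sized region placed right after it (`bid_tmp.offs += ggml_nbytes(op)`), and top-k reserves two full int32 index buffers sized by the source row length, since intermediate merges can be wider than the final top-k output. A small sketch of that accounting (the element counts are made-up stand-ins, not real tensors):

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical argsort op: dst holds one int32 index per source element.
    const int64_t n_src         = 8 * 4096;                    // made-up element count
    const size_t  dst_bytes     = (size_t) n_src * sizeof(int32_t);
    const size_t  alloc_argsort = 2 * dst_bytes;               // dst + ping-pong tmp region

    // Hypothetical top_k op: scratch sizing follows the *source* element count.
    const size_t  alloc_top_k   = 2 * sizeof(int32_t) * (size_t) n_src;

    std::printf("argsort alloc: %zu bytes, top_k alloc: %zu bytes\n",
                alloc_argsort, alloc_top_k);
    return 0;
}
```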
| 239 | |||
| 240 | // default (shared) buffer type | ||
| 241 | |||
| 242 | static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) { | ||
| 243 | ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context; | ||
| 244 | |||
| 245 | return ctx->name.c_str(); | ||
| 246 | } | ||
| 247 | |||
| 248 | static ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { | ||
| 249 | return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true); | ||
| 250 | } | ||
| 251 | |||
| 252 | static size_t ggml_backend_metal_buffer_type_shared_get_alignment(ggml_backend_buffer_type_t buft) { | ||
| 253 | return 32; | ||
| 254 | |||
| 255 | GGML_UNUSED(buft); | ||
| 256 | } | ||
| 257 | |||
| 258 | static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_buffer_type_t buft) { | ||
| 259 | ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context; | ||
| 260 | |||
| 261 | return ggml_metal_device_get_props(ctx_dev)->max_buffer_size; | ||
| 262 | } | ||
| 263 | |||
| 264 | static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { | ||
| 265 | return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor); | ||
| 266 | } | ||
| 267 | |||
| 268 | static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) { | ||
| 269 | return false; | ||
| 270 | |||
| 271 | GGML_UNUSED(buft); | ||
| 272 | } | ||
| 273 | |||
| 274 | static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(int device) { | ||
| 275 | static std::mutex mutex; | ||
| 276 | std::lock_guard<std::mutex> lock(mutex); | ||
| 277 | |||
| 278 | static std::vector<ggml_backend_buffer_type> bufts; | ||
| 279 | static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs; | ||
| 280 | |||
| 281 | static bool initialized = false; | ||
| 282 | if (!initialized) { | ||
| 283 | bufts.reserve(g_devices); | ||
| 284 | ctxs.reserve(g_devices); | ||
| 285 | |||
| 286 | for (int i = 0; i < g_devices; ++i) { | ||
| 287 | ggml_backend_metal_buffer_type * raw_ctx = | ||
| 288 | new ggml_backend_metal_buffer_type { | ||
| 289 | /* .device = */ i, | ||
| 290 | /* .name = */ GGML_METAL_NAME + std::to_string(i), | ||
| 291 | }; | ||
| 292 | ctxs.emplace_back(raw_ctx); | ||
| 293 | |||
| 294 | ggml_backend_buffer_type buft = { | ||
| 295 | /* .iface = */ { | ||
| 296 | /* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name, | ||
| 297 | /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer, | ||
| 298 | /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment, | ||
| 299 | /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size, | ||
| 300 | /* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size, | ||
| 301 | /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host, | ||
| 302 | }, | ||
| 303 | /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i), | ||
| 304 | /* .context = */ raw_ctx, | ||
| 305 | }; | ||
| 306 | |||
| 307 | bufts.emplace_back(buft); | ||
| 308 | } | ||
| 309 | |||
| 310 | initialized = true; | ||
| 311 | } | ||
| 312 | |||
| 313 | return &bufts[device]; | ||
| 314 | } | ||
| 315 | |||
| 316 | // default (private) buffer type | ||
| 317 | |||
| 318 | static const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) { | ||
| 319 | ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context; | ||
| 320 | |||
| 321 | return ctx->name.c_str(); | ||
| 322 | } | ||
| 323 | |||
| 324 | static ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { | ||
| 325 | return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, false); | ||
| 326 | } | ||
| 327 | |||
| 328 | static size_t ggml_backend_metal_buffer_type_private_get_alignment(ggml_backend_buffer_type_t buft) { | ||
| 329 | return 32; | ||
| 330 | |||
| 331 | GGML_UNUSED(buft); | ||
| 332 | } | ||
| 333 | |||
| 334 | static size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_buffer_type_t buft) { | ||
| 335 | ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context; | ||
| 336 | |||
| 337 | return ggml_metal_device_get_props(ctx_dev)->max_buffer_size; | ||
| 338 | } | ||
| 339 | |||
| 340 | static size_t ggml_backend_metal_buffer_type_private_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { | ||
| 341 | return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor); | ||
| 342 | } | ||
| 343 | |||
| 344 | static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) { | ||
| 345 | return false; | ||
| 346 | |||
| 347 | GGML_UNUSED(buft); | ||
| 348 | } | ||
| 349 | |||
| 350 | static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(int device) { | ||
| 351 | static std::mutex mutex; | ||
| 352 | std::lock_guard<std::mutex> lock(mutex); | ||
| 353 | |||
| 354 | static std::vector<ggml_backend_buffer_type> bufts; | ||
| 355 | static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs; | ||
| 356 | |||
| 357 | static bool initialized = false; | ||
| 358 | if (!initialized) { | ||
| 359 | bufts.reserve(g_devices); | ||
| 360 | ctxs.reserve(g_devices); | ||
| 361 | |||
| 362 | for (int i = 0; i < g_devices; ++i) { | ||
| 363 | ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{ | ||
| 364 | /* .device = */ i, | ||
| 365 | /* .name = */ GGML_METAL_NAME + std::to_string(i) + "_Private" | ||
| 366 | }; | ||
| 367 | ctxs.emplace_back(raw_ctx); | ||
| 368 | |||
| 369 | ggml_backend_buffer_type buft = { | ||
| 370 | /* .iface = */ { | ||
| 371 | /* .get_name = */ ggml_backend_metal_buffer_type_private_get_name, | ||
| 372 | /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer, | ||
| 373 | /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment, | ||
| 374 | /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size, | ||
| 375 | /* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size, | ||
| 376 | /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host, | ||
| 377 | }, | ||
| 378 | /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i), | ||
| 379 | /* .context = */ raw_ctx, | ||
| 380 | }; | ||
| 381 | |||
| 382 | bufts.emplace_back(buft); | ||
| 383 | } | ||
| 384 | |||
| 385 | initialized = true; | ||
| 386 | } | ||
| 387 | |||
| 388 | return &bufts[device]; | ||
| 389 | } | ||
| 390 | |||
| 391 | // mapped buffer type | ||
| 392 | |||
| 393 | static const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) { | ||
| 394 | ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context; | ||
| 395 | |||
| 396 | return ctx->name.c_str(); | ||
| 397 | } | ||
| 398 | |||
| 399 | static ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { | ||
| 400 | // for mapped buffers, prefer shared memory | ||
| 401 | return ggml_backend_metal_buffer_type_alloc_buffer(buft, size, true); | ||
| 402 | } | ||
| 403 | |||
| 404 | static size_t ggml_backend_metal_buffer_type_mapped_get_alignment(ggml_backend_buffer_type_t buft) { | ||
| 405 | return 32; | ||
| 406 | |||
| 407 | GGML_UNUSED(buft); | ||
| 408 | } | ||
| 409 | |||
| 410 | static size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_buffer_type_t buft) { | ||
| 411 | ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context; | ||
| 412 | |||
| 413 | return ggml_metal_device_get_props(ctx_dev)->max_buffer_size; | ||
| 414 | } | ||
| 415 | |||
| 416 | static size_t ggml_backend_metal_buffer_type_mapped_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { | ||
| 417 | return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor); | ||
| 418 | } | ||
| 419 | |||
| 420 | static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) { | ||
| 421 | return false; | ||
| 422 | |||
| 423 | GGML_UNUSED(buft); | ||
| 424 | } | ||
| 425 | |||
| 426 | static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(int device) { | ||
| 427 | static std::mutex mutex; | ||
| 428 | std::lock_guard<std::mutex> lock(mutex); | ||
| 429 | |||
| 430 | static std::vector<ggml_backend_buffer_type> bufts; | ||
| 431 | static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs; | ||
| 432 | |||
| 433 | static bool initialized = false; | ||
| 434 | if (!initialized) { | ||
| 435 | bufts.reserve(g_devices); | ||
| 436 | ctxs.reserve(g_devices); | ||
| 437 | |||
| 438 | for (int i = 0; i < g_devices; ++i) { | ||
| 439 | ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{ | ||
| 440 | /* .device = */ i, | ||
| 441 | /* .name = */ GGML_METAL_NAME + std::to_string(i) + "_Mapped" | ||
| 442 | }; | ||
| 443 | ctxs.emplace_back(raw_ctx); | ||
| 444 | |||
| 445 | // note: not obvious, but this buffer type still needs to implement .alloc_buffer: | ||
| 446 | // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099 | ||
| 447 | ggml_backend_buffer_type buft = { | ||
| 448 | /* .iface = */ { | ||
| 449 | /* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name, | ||
| 450 | /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer, | ||
| 451 | /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment, | ||
| 452 | /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size, | ||
| 453 | /* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size, | ||
| 454 | /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host, | ||
| 455 | }, | ||
| 456 | /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i), | ||
| 457 | /* .context = */ raw_ctx, | ||
| 458 | }; | ||
| 459 | |||
| 460 | bufts.emplace_back(buft); | ||
| 461 | } | ||
| 462 | |||
| 463 | initialized = true; | ||
| 464 | } | ||
| 465 | |||
| 466 | return &bufts[device]; | ||
| 467 | } | ||
| 468 | |||
| 469 | // backend | ||
| 470 | |||
| 471 | static const char * ggml_backend_metal_name(ggml_backend_t backend) { | ||
| 472 | ggml_metal_t ctx = (ggml_metal_t)backend->context; | ||
| 473 | |||
| 474 | return ggml_metal_get_name(ctx); | ||
| 475 | } | ||
| 476 | |||
| 477 | static void ggml_backend_metal_free(ggml_backend_t backend) { | ||
| 478 | ggml_metal_t ctx = (ggml_metal_t)backend->context; | ||
| 479 | |||
| 480 | // wait for any ongoing async operations to finish | ||
| 481 | ggml_metal_synchronize(ctx); | ||
| 482 | |||
| 483 | ggml_metal_free(ctx); | ||
| 484 | |||
| 485 | free(backend); | ||
| 486 | } | ||
| 487 | |||
| 488 | static void ggml_backend_metal_synchronize(ggml_backend_t backend) { | ||
| 489 | ggml_metal_t ctx = (ggml_metal_t)backend->context; | ||
| 490 | |||
| 491 | ggml_metal_synchronize(ctx); | ||
| 492 | } | ||
| 493 | |||
| 494 | static void ggml_backend_metal_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { | ||
| 495 | ggml_metal_t ctx = (ggml_metal_t)backend->context; | ||
| 496 | |||
| 497 | ggml_metal_set_tensor_async(ctx, tensor, data, offset, size); | ||
| 498 | } | ||
| 499 | |||
| 500 | static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { | ||
| 501 | ggml_metal_t ctx = (ggml_metal_t)backend->context; | ||
| 502 | |||
| 503 | ggml_metal_get_tensor_async(ctx, tensor, data, offset, size); | ||
| 504 | } | ||
| 505 | |||
| 506 | static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { | ||
| 507 | if (!ggml_backend_is_metal(backend_src) || !ggml_backend_is_metal(backend_dst)) { | ||
| 508 | return false; | ||
| 509 | } | ||
| 510 | |||
| 511 | if (!ggml_backend_buffer_is_metal(src->buffer) || !ggml_backend_buffer_is_metal(dst->buffer)) { | ||
| 512 | return false; | ||
| 513 | } | ||
| 514 | |||
| 515 | ggml_metal_t ctx_src = (ggml_metal_t)backend_src->context; | ||
| 516 | ggml_metal_t ctx_dst = (ggml_metal_t)backend_dst->context; | ||
| 517 | |||
| 518 | //ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer; | ||
| 519 | //ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer; | ||
| 520 | |||
| 521 | //ggml_metal_buffer_t buf_ctx_src = (ggml_metal_buffer_t)buf_src->context; | ||
| 522 | //ggml_metal_buffer_t buf_ctx_dst = (ggml_metal_buffer_t)buf_dst->context; | ||
| 523 | |||
| 524 | return ggml_metal_cpy_tensor_async(ctx_src, ctx_dst, src, dst); | ||
| 525 | } | ||
| 526 | |||
| 527 | static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { | ||
| 528 | ggml_metal_t ctx = (ggml_metal_t)backend->context; | ||
| 529 | |||
| 530 | return ggml_metal_graph_compute(ctx, cgraph); | ||
| 531 | } | ||
| 532 | |||
| 533 | static void ggml_backend_metal_event_record(ggml_backend_t backend, ggml_backend_event_t event) { | ||
| 534 | ggml_metal_t ctx = (ggml_metal_t)backend->context; | ||
| 535 | ggml_metal_event_t ev = (ggml_metal_event_t)event->context; | ||
| 536 | |||
| 537 | ggml_metal_event_record(ctx, ev); | ||
| 538 | } | ||
| 539 | |||
| 540 | static void ggml_backend_metal_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { | ||
| 541 | ggml_metal_t ctx = (ggml_metal_t)backend->context; | ||
| 542 | ggml_metal_event_t ev = (ggml_metal_event_t)event->context; | ||
| 543 | |||
| 544 | ggml_metal_event_wait(ctx, ev); | ||
| 545 | } | ||
| 546 | |||
| 547 | static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { | ||
| 548 | ggml_metal_t ctx = (ggml_metal_t)backend->context; | ||
| 549 | |||
| 550 | ggml_metal_graph_optimize(ctx, cgraph); | ||
| 551 | } | ||
| 552 | |||
| 553 | static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { | ||
| 554 | GGML_ASSERT(ggml_backend_is_metal(backend)); | ||
| 555 | |||
| 556 | ggml_metal_t ctx = (ggml_metal_t)backend->context; | ||
| 557 | |||
| 558 | ggml_metal_set_n_cb(ctx, n_cb); | ||
| 559 | } | ||
| 560 | |||
| 561 | static ggml_backend_i ggml_backend_metal_i = { | ||
| 562 | /* .get_name = */ ggml_backend_metal_name, | ||
| 563 | /* .free = */ ggml_backend_metal_free, | ||
| 564 | /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async, | ||
| 565 | /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async, | ||
| 566 | /* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups | ||
| 567 | /* .synchronize = */ ggml_backend_metal_synchronize, | ||
| 568 | /* .graph_plan_create = */ NULL, | ||
| 569 | /* .graph_plan_free = */ NULL, | ||
| 570 | /* .graph_plan_update = */ NULL, | ||
| 571 | /* .graph_plan_compute = */ NULL, | ||
| 572 | /* .graph_compute = */ ggml_backend_metal_graph_compute, | ||
| 573 | /* .event_record = */ ggml_backend_metal_event_record, | ||
| 574 | /* .event_wait = */ ggml_backend_metal_event_wait, | ||
| 575 | /* .graph_optimize = */ ggml_backend_metal_graph_optimize, | ||
| 576 | }; | ||
| 577 | |||
| 578 | static ggml_guid_t ggml_backend_metal_guid(void) { | ||
| 579 | static ggml_guid guid = { 0x81, 0xa1, 0x8b, 0x1e, 0x71, 0xec, 0x79, 0xed, 0x2b, 0x85, 0xdc, 0x8a, 0x61, 0x98, 0x30, 0xe6 }; | ||
| 580 | return &guid; | ||
| 581 | } | ||
| 582 | |||
| 583 | ggml_backend_t ggml_backend_metal_init(void) { | ||
| 584 | ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0); | ||
| 585 | ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; | ||
| 586 | |||
| 587 | ggml_metal_t ctx = ggml_metal_init(ctx_dev); | ||
| 588 | if (ctx == NULL) { | ||
| 589 | GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); | ||
| 590 | return NULL; | ||
| 591 | } | ||
| 592 | |||
| 593 | ggml_backend_t backend = (ggml_backend_t) malloc(sizeof(ggml_backend)); | ||
| 594 | |||
| 595 | *backend = { | ||
| 596 | /* .guid = */ ggml_backend_metal_guid(), | ||
| 597 | /* .interface = */ ggml_backend_metal_i, | ||
| 598 | /* .device = */ dev, | ||
| 599 | /* .context = */ ctx, | ||
| 600 | }; | ||
| 601 | |||
| 602 | ggml_backend_metal_set_n_cb(backend, 1); | ||
| 603 | |||
| 604 | return backend; | ||
| 605 | } | ||
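`ggml_backend_metal_init()` allocates a Metal context on device 0, wires it to the backend interface above, and defaults to a single command buffer. A hedged usage sketch, assuming only the public `ggml-backend.h`/`ggml-metal.h` headers and an already-built graph (none of this is part of the diff):

```cpp
#include "ggml-backend.h"
#include "ggml-metal.h"

// Sketch: run one graph on the Metal backend, then release it again.
static enum ggml_status run_on_metal(struct ggml_cgraph * gf) {
    ggml_backend_t backend = ggml_backend_metal_init();
    if (backend == NULL) {
        return GGML_STATUS_FAILED; // device or context allocation failed
    }

    const enum ggml_status status = ggml_backend_graph_compute(backend, gf);

    ggml_backend_free(backend); // dispatches to ggml_backend_metal_free()
    return status;
}
```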
| 606 | |||
| 607 | bool ggml_backend_is_metal(ggml_backend_t backend) { | ||
| 608 | return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid()); | ||
| 609 | } | ||
| 610 | |||
| 611 | void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) { | ||
| 612 | GGML_ASSERT(ggml_backend_is_metal(backend)); | ||
| 613 | |||
| 614 | ggml_metal_t ctx = (ggml_metal_t)backend->context; | ||
| 615 | |||
| 616 | ggml_metal_set_abort_callback(ctx, abort_callback, user_data); | ||
| 617 | } | ||
| 618 | |||
| 619 | bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) { | ||
| 620 | GGML_ASSERT(ggml_backend_is_metal(backend)); | ||
| 621 | |||
| 622 | ggml_metal_t ctx = (ggml_metal_t)backend->context; | ||
| 623 | |||
| 624 | return ggml_metal_supports_family(ctx, family); | ||
| 625 | } | ||
| 626 | |||
| 627 | void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) { | ||
| 628 | GGML_ASSERT(ggml_backend_is_metal(backend)); | ||
| 629 | |||
| 630 | ggml_metal_t ctx = (ggml_metal_t)backend->context; | ||
| 631 | |||
| 632 | ggml_metal_capture_next_compute(ctx); | ||
| 633 | } | ||
| 634 | |||
| 635 | // backend device | ||
| 636 | |||
| 637 | static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) { | ||
| 638 | ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; | ||
| 639 | |||
| 640 | const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); | ||
| 641 | |||
| 642 | return props_dev->name; | ||
| 643 | } | ||
| 644 | |||
| 645 | static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) { | ||
| 646 | ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; | ||
| 647 | |||
| 648 | return ggml_metal_device_get_props(ctx_dev)->desc; | ||
| 649 | } | ||
| 650 | |||
| 651 | static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { | ||
| 652 | ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; | ||
| 653 | |||
| 654 | ggml_metal_device_get_memory(ctx_dev, free, total); | ||
| 655 | } | ||
| 656 | |||
| 657 | static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) { | ||
| 658 | return GGML_BACKEND_DEVICE_TYPE_GPU; | ||
| 659 | |||
| 660 | GGML_UNUSED(dev); | ||
| 661 | } | ||
| 662 | |||
| 663 | static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { | ||
| 664 | props->name = ggml_backend_metal_device_get_name(dev); | ||
| 665 | props->description = ggml_backend_metal_device_get_description(dev); | ||
| 666 | props->type = ggml_backend_metal_device_get_type(dev); | ||
| 667 | |||
| 668 | ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); | ||
| 669 | |||
| 670 | props->caps = { | ||
| 671 | /* .async = */ true, | ||
| 672 | /* .host_buffer = */ false, | ||
| 673 | /* .buffer_from_host_ptr = */ true, | ||
| 674 | /* .events = */ true, | ||
| 675 | }; | ||
| 676 | } | ||
| 677 | |||
| 678 | static ggml_backend_t ggml_backend_metal_device_init_backend(ggml_backend_dev_t dev, const char * params) { | ||
| 679 | ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; | ||
| 680 | |||
| 681 | ggml_metal_t ctx = ggml_metal_init(ctx_dev); | ||
| 682 | if (ctx == NULL) { | ||
| 683 | GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__); | ||
| 684 | return NULL; | ||
| 685 | } | ||
| 686 | |||
| 687 | ggml_backend_t backend = (ggml_backend_t) malloc(sizeof(ggml_backend)); | ||
| 688 | |||
| 689 | *backend = { | ||
| 690 | /* .guid = */ ggml_backend_metal_guid(), | ||
| 691 | /* .interface = */ ggml_backend_metal_i, | ||
| 692 | /* .device = */ dev, | ||
| 693 | /* .context = */ ctx, | ||
| 694 | }; | ||
| 695 | |||
| 696 | ggml_backend_metal_set_n_cb(backend, 1); | ||
| 697 | |||
| 698 | return backend; | ||
| 699 | |||
| 700 | GGML_UNUSED(params); | ||
| 701 | } | ||
| 702 | |||
| 703 | static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) { | ||
| 704 | ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; | ||
| 705 | |||
| 706 | const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); | ||
| 707 | |||
| 708 | return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared(props_dev->device) : ggml_backend_metal_buffer_type_private(props_dev->device); | ||
| 709 | } | ||
| 710 | |||
| 711 | static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { | ||
| 712 | ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; | ||
| 713 | |||
| 714 | ggml_metal_buffer_t res = ggml_metal_buffer_map(ctx_dev, ptr, size, max_tensor_size); | ||
| 715 | |||
| 716 | const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); | ||
| 717 | |||
| 718 | return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(props_dev->device), ggml_backend_metal_buffer_shared_i, res, size); | ||
| 719 | } | ||
| 720 | |||
| 721 | static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { | ||
| 722 | ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; | ||
| 723 | |||
| 724 | return ggml_metal_device_supports_op(ctx_dev, op); | ||
| 725 | } | ||
| 726 | |||
| 727 | static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { | ||
| 728 | return | ||
| 729 | buft->device == dev && ( | ||
| 730 | buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name || | ||
| 731 | buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name || | ||
| 732 | buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name); | ||
| 733 | |||
| 734 | GGML_UNUSED(dev); | ||
| 735 | } | ||
| 736 | |||
| 737 | static int64_t get_op_batch_size(const ggml_tensor * op) { | ||
| 738 | switch (op->op) { | ||
| 739 | case GGML_OP_MUL_MAT: | ||
| 740 | return op->ne[1]; | ||
| 741 | case GGML_OP_MUL_MAT_ID: | ||
| 742 | return op->ne[2]; | ||
| 743 | default: | ||
| 744 | return ggml_nrows(op); | ||
| 745 | } | ||
| 746 | } | ||
| 747 | |||
| 748 | static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { | ||
| 749 | ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; | ||
| 750 | |||
| 751 | return (op->op == GGML_OP_MUL_MAT || | ||
| 752 | op->op == GGML_OP_MUL_MAT_ID) && | ||
| 753 | get_op_batch_size(op) >= ggml_metal_device_get_props(ctx_dev)->op_offload_min_batch_size; | ||
| 754 | } | ||
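`offload_op` reports whether an op is worth moving to this device: only matrix multiplications whose batch dimension (`ne[1]` for `MUL_MAT`, `ne[2]` for `MUL_MAT_ID`) reaches the device's `op_offload_min_batch_size` qualify, so single-token decoding stays put while large prompt batches are offloaded. An illustrative sketch with a hypothetical threshold, reusing the helper above (the real value comes from `ggml_metal_device_get_props()`):

```cpp
// Illustration only: min_batch_size stands in for props->op_offload_min_batch_size.
static bool would_offload(const ggml_tensor * op, int64_t min_batch_size /* e.g. 32 */) {
    const bool is_mm = op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID;
    // batch size 1 (token-by-token decoding) fails the threshold; a 512-token prompt passes
    return is_mm && get_op_batch_size(op) >= min_batch_size;
}
```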
| 755 | |||
| 756 | static ggml_backend_event_t ggml_backend_metal_device_event_new(ggml_backend_dev_t dev) { | ||
| 757 | ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; | ||
| 758 | |||
| 759 | ggml_metal_event_t event = ggml_metal_device_event_init(ctx_dev); | ||
| 760 | GGML_ASSERT(event); | ||
| 761 | |||
| 762 | ggml_backend_event_t ev = new ggml_backend_event { | ||
| 763 | /* .device = */ dev, | ||
| 764 | /* .context = */ event, | ||
| 765 | }; | ||
| 766 | |||
| 767 | return ev; | ||
| 768 | } | ||
| 769 | |||
| 770 | static void ggml_backend_metal_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) { | ||
| 771 | ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; | ||
| 772 | |||
| 773 | ggml_metal_event_t ev = (ggml_metal_event_t)event->context; | ||
| 774 | |||
| 775 | ggml_metal_device_event_free(ctx_dev, ev); | ||
| 776 | |||
| 777 | delete event; | ||
| 778 | } | ||
| 779 | |||
| 780 | static void ggml_backend_metal_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { | ||
| 781 | ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; | ||
| 782 | |||
| 783 | ggml_metal_event_t evt = (ggml_metal_event_t)event->context; | ||
| 784 | |||
| 785 | ggml_metal_device_event_synchronize(ctx_dev, evt); | ||
| 786 | } | ||
| 787 | |||
| 788 | static ggml_backend_device_i ggml_backend_metal_device_i = { | ||
| 789 | /* .get_name = */ ggml_backend_metal_device_get_name, | ||
| 790 | /* .get_description = */ ggml_backend_metal_device_get_description, | ||
| 791 | /* .get_memory = */ ggml_backend_metal_device_get_memory, | ||
| 792 | /* .get_type = */ ggml_backend_metal_device_get_type, | ||
| 793 | /* .get_props = */ ggml_backend_metal_device_get_props, | ||
| 794 | /* .init_backend = */ ggml_backend_metal_device_init_backend, | ||
| 795 | /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type, | ||
| 796 | /* .get_host_buffer_type = */ NULL, | ||
| 797 | /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped, | ||
| 798 | /* .supports_op = */ ggml_backend_metal_device_supports_op, | ||
| 799 | /* .supports_buft = */ ggml_backend_metal_device_supports_buft, | ||
| 800 | /* .offload_op = */ ggml_backend_metal_device_offload_op, | ||
| 801 | /* .event_new = */ ggml_backend_metal_device_event_new, | ||
| 802 | /* .event_free = */ ggml_backend_metal_device_event_free, | ||
| 803 | /* .event_synchronize = */ ggml_backend_metal_device_event_synchronize, | ||
| 804 | }; | ||
| 805 | |||
| 806 | // backend registry | ||
| 807 | |||
| 808 | struct ggml_backend_metal_reg { | ||
| 809 | std::vector<ggml_backend_dev_t> devices; | ||
| 810 | }; | ||
| 811 | |||
| 812 | typedef struct ggml_backend_metal_reg * ggml_backend_metal_reg_t; | ||
| 813 | |||
| 814 | static ggml_backend_metal_reg_t ggml_backend_metal_reg_init(void) { | ||
| 815 | ggml_backend_metal_reg_t ctx = new struct ggml_backend_metal_reg; | ||
| 816 | |||
| 817 | return ctx; | ||
| 818 | } | ||
| 819 | |||
| 820 | static void ggml_backend_metal_reg_free(ggml_backend_metal_reg_t ctx) { | ||
| 821 | delete ctx; | ||
| 822 | } | ||
| 823 | |||
| 824 | struct ggml_backend_metal_reg_deleter { | ||
| 825 | void operator()(ggml_backend_metal_reg_t ctx) { | ||
| 826 | ggml_backend_metal_reg_free(ctx); | ||
| 827 | } | ||
| 828 | }; | ||
| 829 | |||
| 830 | typedef std::unique_ptr<struct ggml_backend_metal_reg, ggml_backend_metal_reg_deleter> ggml_backend_metal_reg_ptr; | ||
| 831 | |||
| 832 | static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) { | ||
| 833 | return GGML_METAL_NAME; | ||
| 834 | |||
| 835 | GGML_UNUSED(reg); | ||
| 836 | } | ||
| 837 | |||
| 838 | static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) { | ||
| 839 | ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context; | ||
| 840 | return ctx->devices.size(); | ||
| 841 | } | ||
| 842 | |||
| 843 | static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) { | ||
| 844 | ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context; | ||
| 845 | GGML_ASSERT(index < ctx->devices.size()); | ||
| 846 | return ctx->devices[index]; | ||
| 847 | } | ||
| 848 | |||
| 849 | static ggml_backend_feature g_ggml_backend_metal_features[] = { | ||
| 850 | #if defined(GGML_METAL_EMBED_LIBRARY) | ||
| 851 | { "EMBED_LIBRARY", "1" }, | ||
| 852 | #endif | ||
| 853 | { NULL, NULL }, | ||
| 854 | }; | ||
| 855 | |||
| 856 | static ggml_backend_feature * ggml_backend_metal_get_features(ggml_backend_reg_t reg) { | ||
| 857 | return g_ggml_backend_metal_features; | ||
| 858 | |||
| 859 | GGML_UNUSED(reg); | ||
| 860 | } | ||
| 861 | |||
| 862 | static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const char * name) { | ||
| 863 | if (strcmp(name, "ggml_backend_get_features") == 0) { | ||
| 864 | return (void *)ggml_backend_metal_get_features; | ||
| 865 | } | ||
| 866 | |||
| 867 | return NULL; | ||
| 868 | |||
| 869 | GGML_UNUSED(reg); | ||
| 870 | } | ||
| 871 | |||
| 872 | static ggml_backend_reg_i ggml_backend_metal_reg_i = { | ||
| 873 | /* .get_name = */ ggml_backend_metal_reg_get_name, | ||
| 874 | /* .get_device_count = */ ggml_backend_metal_reg_device_count, | ||
| 875 | /* .get_device = */ ggml_backend_metal_reg_device_get, | ||
| 876 | /* .get_proc_address = */ ggml_backend_metal_get_proc_address, | ||
| 877 | }; | ||
| 878 | |||
| 879 | static ggml_backend_dev_t ggml_backend_metal_device_init(ggml_backend_reg_t reg, int device) { | ||
| 880 | return new ggml_backend_device { | ||
| 881 | /* .iface = */ ggml_backend_metal_device_i, | ||
| 882 | /* .reg = */ reg, | ||
| 883 | /* .context = */ ggml_metal_device_get(device), | ||
| 884 | }; | ||
| 885 | } | ||
| 886 | |||
| 887 | static void ggml_backend_metal_device_free(ggml_backend_dev_t dev) { | ||
| 888 | delete dev; | ||
| 889 | } | ||
| 890 | |||
| 891 | struct ggml_backend_device_deleter { | ||
| 892 | void operator()(ggml_backend_dev_t ctx) { | ||
| 893 | ggml_backend_metal_device_free(ctx); | ||
| 894 | } | ||
| 895 | }; | ||
| 896 | |||
| 897 | typedef std::unique_ptr<ggml_backend_device, ggml_backend_device_deleter> ggml_backend_device_ptr; | ||
| 898 | |||
| 899 | ggml_backend_reg_t ggml_backend_metal_reg(void) { | ||
| 900 | static ggml_backend_reg reg; | ||
| 901 | static bool initialized = false; | ||
| 902 | |||
| 903 | { | ||
| 904 | static std::mutex mutex; | ||
| 905 | std::lock_guard<std::mutex> lock(mutex); | ||
| 906 | |||
| 907 | const char * env = getenv("GGML_METAL_DEVICES"); | ||
| 908 | if (env) { | ||
| 909 | g_devices = atoi(env); | ||
| 910 | } | ||
| 911 | |||
| 912 | static std::vector<ggml_backend_device_ptr> devs; | ||
| 913 | |||
| 914 | if (!initialized) { | ||
| 915 | static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init()); | ||
| 916 | |||
| 917 | for (int i = 0; i < g_devices; ++i) { | ||
| 918 | auto * dev = ggml_backend_metal_device_init(&reg, i); | ||

| 919 | devs.emplace_back(dev); | ||
| 920 | |||
| 921 | reg_ctx->devices.push_back(dev); | ||
| 922 | } | ||
| 923 | |||
| 924 | reg = { | ||
| 925 | /* .api_version = */ GGML_BACKEND_API_VERSION, | ||
| 926 | /* .iface = */ ggml_backend_metal_reg_i, | ||
| 927 | /* .context = */ reg_ctx.get(), | ||
| 928 | }; | ||
| 929 | } | ||
| 930 | |||
| 931 | initialized = true; | ||
| 932 | } | ||
| 933 | |||
| 934 | return &reg; | ||
| 935 | } | ||
| 936 | |||
| 937 | GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg) | ||
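The registry honours the `GGML_METAL_DEVICES` environment variable, but the device list is only built on the first call to `ggml_backend_metal_reg()`, so the override must be in place before the registry is first queried. A hedged sketch using standard POSIX `setenv` and the public ggml headers (not part of this diff):

```cpp
#include <cstdlib>

#include "ggml-backend.h"
#include "ggml-metal.h"

int main(void) {
    // must run before anything in the process touches the Metal registry
    setenv("GGML_METAL_DEVICES", "2", /*overwrite=*/1);

    ggml_backend_reg_t reg = ggml_backend_metal_reg();
    const size_t n_dev = ggml_backend_reg_dev_count(reg);

    return n_dev > 0 ? 0 : 1;
}
```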
diff --git a/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal b/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal new file mode 100644 index 0000000..0036ba9 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-metal/ggml-metal.metal | |||
| @@ -0,0 +1,9798 @@ | |||
| 1 | #define GGML_COMMON_DECL_METAL | ||
| 2 | #define GGML_COMMON_IMPL_METAL | ||
| 3 | #if defined(GGML_METAL_EMBED_LIBRARY) | ||
| 4 | __embed_ggml-common.h__ | ||
| 5 | #else | ||
| 6 | #include "ggml-common.h" | ||
| 7 | #endif | ||
| 8 | #include "ggml-metal-impl.h" | ||
| 9 | |||
| 10 | #include <metal_stdlib> | ||
| 11 | |||
| 12 | #ifdef GGML_METAL_HAS_TENSOR | ||
| 13 | #include <metal_tensor> | ||
| 14 | |||
| 15 | #include <MetalPerformancePrimitives/MetalPerformancePrimitives.h> | ||
| 16 | #endif | ||
| 17 | |||
| 18 | using namespace metal; | ||
| 19 | |||
| 20 | #define MAX(x, y) ((x) > (y) ? (x) : (y)) | ||
| 21 | #define MIN(x, y) ((x) < (y) ? (x) : (y)) | ||
| 22 | #define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } | ||
| 23 | |||
| 24 | #define PAD2(x, n) (((x) + (n) - 1) & ~((n) - 1)) | ||
| 25 | |||
| 26 | #define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x) | ||
| 27 | |||
| 28 | #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 | ||
| 29 | |||
| 30 | // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf | ||
| 31 | // | ||
| 32 | // cmd: | ||
| 33 | // .../usr/bin/metal -dM -E -c ggml/src/ggml-metal/ggml-metal.metal | ||
| 34 | // .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal/ggml-metal.metal | ||
| 35 | // | ||
| 36 | #if __METAL_VERSION__ < 310 && defined(GGML_METAL_HAS_BF16) | ||
| 37 | #undef GGML_METAL_HAS_BF16 | ||
| 38 | #endif | ||
| 39 | |||
| 40 | #if defined(GGML_METAL_HAS_BF16) | ||
| 41 | typedef matrix<bfloat, 4, 4> bfloat4x4; | ||
| 42 | typedef matrix<bfloat, 2, 4> bfloat2x4; | ||
| 43 | #endif | ||
| 44 | |||
| 45 | constexpr constant static float kvalues_iq4nl_f[16] = { | ||
| 46 | -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f | ||
| 47 | }; | ||
| 48 | |||
| 49 | constexpr constant static float kvalues_mxfp4_f[16] = { | ||
| 50 | 0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f | ||
| 51 | }; | ||
| 52 | |||
| 53 | static inline int best_index_int8(int n, constant float * val, float x) { | ||
| 54 | if (x <= val[0]) return 0; | ||
| 55 | if (x >= val[n-1]) return n-1; | ||
| 56 | int ml = 0, mu = n-1; | ||
| 57 | while (mu-ml > 1) { | ||
| 58 | int mav = (ml+mu)/2; | ||
| 59 | if (x < val[mav]) mu = mav; else ml = mav; | ||
| 60 | } | ||
| 61 | return x - val[mu-1] < val[mu] - x ? mu-1 : mu; | ||
| 62 | } | ||
| 63 | |||
| 64 | static inline float e8m0_to_fp32(uint8_t x) { | ||
| 65 | uint32_t bits; | ||
| 66 | |||
| 67 | if (x == 0) { | ||
| 68 | bits = 0x00400000; | ||
| 69 | } else { | ||
| 70 | bits = (uint32_t) x << 23; | ||
| 71 | } | ||
| 72 | |||
| 73 | return as_type<float>(bits); | ||
| 74 | } | ||
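`e8m0_to_fp32` decodes the MXFP4 block scale: an E8M0 value is just an 8-bit biased exponent, i.e. 2^(x-127), so writing `x` straight into the fp32 exponent field (`x << 23`) yields that power of two, and `x == 0` is special-cased to the subnormal `0x00400000` (2^-127) because a zero exponent field would otherwise decode to the wrong value. A host-side check of the same arithmetic, for illustration only:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Host-side equivalent of the Metal helper above (illustration only).
static float e8m0_to_fp32_host(uint8_t x) {
    const uint32_t bits = (x == 0) ? 0x00400000u       // subnormal 2^-127
                                   : ((uint32_t) x << 23);
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main(void) {
    assert(e8m0_to_fp32_host(127) == 1.0f); // 2^(127-127)
    assert(e8m0_to_fp32_host(128) == 2.0f); // 2^1
    assert(e8m0_to_fp32_host(126) == 0.5f); // 2^-1
    return 0;
}
```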
| 75 | |||
| 76 | static inline float dot(float x, float y) { | ||
| 77 | return x*y; | ||
| 78 | } | ||
| 79 | |||
| 80 | // NOTE: this is not dequantizing - we are simply fitting the template | ||
| 81 | template <typename type4x4> | ||
| 82 | void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { | ||
| 83 | reg = (type4x4)(*src); | ||
| 84 | } | ||
| 85 | |||
| 86 | template <typename type4> | ||
| 87 | void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) { | ||
| 88 | reg = (type4)(*src); | ||
| 89 | } | ||
| 90 | |||
| 91 | template <typename type4x4> | ||
| 92 | void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { | ||
| 93 | reg = (type4x4)(*src); | ||
| 94 | } | ||
| 95 | |||
| 96 | template <typename type4> | ||
| 97 | void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) { | ||
| 98 | reg = (type4)(*(src)); | ||
| 99 | } | ||
| 100 | |||
| 101 | #if defined(GGML_METAL_HAS_BF16) | ||
| 102 | template <typename type4x4> | ||
| 103 | void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) { | ||
| 104 | reg = (type4x4)(*src); | ||
| 105 | } | ||
| 106 | |||
| 107 | template <typename type4> | ||
| 108 | void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg) { | ||
| 109 | reg = (type4)(*(src)); | ||
| 110 | } | ||
| 111 | #endif | ||
| 112 | |||
| 113 | template <typename type4x4> | ||
| 114 | void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) { | ||
| 115 | device const uint16_t * qs = ((device const uint16_t *)xb + 1); | ||
| 116 | const float d1 = il ? (xb->d / 16.h) : xb->d; | ||
| 117 | const float d2 = d1 / 256.f; | ||
| 118 | const float md = -8.h * xb->d; | ||
| 119 | const ushort mask0 = il ? 0x00F0 : 0x000F; | ||
| 120 | const ushort mask1 = mask0 << 8; | ||
| 121 | |||
| 122 | float4x4 reg_f; | ||
| 123 | |||
| 124 | for (int i = 0; i < 8; i++) { | ||
| 125 | reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md; | ||
| 126 | reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md; | ||
| 127 | } | ||
| 128 | |||
| 129 | reg = (type4x4) reg_f; | ||
| 130 | } | ||
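`dequantize_q4_0` reads the 32 packed 4-bit quants as 16-bit words, so the nibble that lives in the high byte comes out pre-multiplied by 256; scaling that lane by `d2 = d1/256` (and the `il = 1` lanes by `d/16`) recovers the plain nibble value without any shifts. Restated as arithmetic (same content as the loop above, just written out): for a word

$$
q = n_0 + 2^4 n_1 + 2^8 n_2 + 2^{12} n_3, \qquad d_2 = \frac{d_1}{256},
$$

$$
d_1\,(q \,\&\, \texttt{0x000F}) = d\,n_0, \qquad d_2\,(q \,\&\, \texttt{0x0F00}) = \frac{d}{256}\cdot 2^{8} n_2 = d\,n_2,
$$

with \(d_1 = d\) for the low nibbles (`il = 0`) and \(d_1 = d/16\) for the high nibbles (`il = 1`); the common offset \(md = -8d\) then maps each unsigned nibble \(n \in [0,15]\) to \((n-8)\,d\).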
| 131 | |||
| 132 | template <typename type4> | ||
| 133 | void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) { | ||
| 134 | device const uint16_t * qs = ((device const uint16_t *)xb + 1); | ||
| 135 | const float d1 = (il/4) ? (xb->d / 16.h) : xb->d; | ||
| 136 | const float d2 = d1 / 256.f; | ||
| 137 | const float md = -8.h * xb->d; | ||
| 138 | const ushort mask0 = (il/4) ? 0x00F0 : 0x000F; | ||
| 139 | const ushort mask1 = mask0 << 8; | ||
| 140 | |||
| 141 | for (int i = 0; i < 2; i++) { | ||
| 142 | reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md; | ||
| 143 | reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md; | ||
| 144 | } | ||
| 145 | } | ||
| 146 | |||
| 147 | void quantize_q4_0(device const float * src, device block_q4_0 & dst) { | ||
| 148 | #pragma METAL fp math_mode(safe) | ||
| 149 | float amax = 0.0f; // absolute max | ||
| 150 | float max = 0.0f; | ||
| 151 | |||
| 152 | for (int j = 0; j < QK4_0; j++) { | ||
| 153 | const float v = src[j]; | ||
| 154 | if (amax < fabs(v)) { | ||
| 155 | amax = fabs(v); | ||
| 156 | max = v; | ||
| 157 | } | ||
| 158 | } | ||
| 159 | |||
| 160 | const float d = max / -8; | ||
| 161 | const float id = d ? 1.0f/d : 0.0f; | ||
| 162 | |||
| 163 | dst.d = d; | ||
| 164 | |||
| 165 | for (int j = 0; j < QK4_0/2; ++j) { | ||
| 166 | const float x0 = src[0 + j]*id; | ||
| 167 | const float x1 = src[QK4_0/2 + j]*id; | ||
| 168 | |||
| 169 | const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); | ||
| 170 | const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); | ||
| 171 | |||
| 172 | dst.qs[j] = xi0; | ||
| 173 | dst.qs[j] |= xi1 << 4; | ||
| 174 | } | ||
| 175 | } | ||
| 176 | |||
| 177 | void quantize_q4_1(device const float * src, device block_q4_1 & dst) { | ||
| 178 | #pragma METAL fp math_mode(safe) | ||
| 179 | float min = FLT_MAX; | ||
| 180 | float max = -FLT_MAX; | ||
| 181 | |||
| 182 | for (int j = 0; j < QK4_1; j++) { | ||
| 183 | const float v = src[j]; | ||
| 184 | if (min > v) min = v; | ||
| 185 | if (max < v) max = v; | ||
| 186 | } | ||
| 187 | |||
| 188 | const float d = (max - min) / ((1 << 4) - 1); | ||
| 189 | const float id = d ? 1.0f/d : 0.0f; | ||
| 190 | |||
| 191 | dst.d = d; | ||
| 192 | dst.m = min; | ||
| 193 | |||
| 194 | for (int j = 0; j < QK4_1/2; ++j) { | ||
| 195 | const float x0 = (src[0 + j] - min)*id; | ||
| 196 | const float x1 = (src[QK4_1/2 + j] - min)*id; | ||
| 197 | |||
| 198 | const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); | ||
| 199 | const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); | ||
| 200 | |||
| 201 | dst.qs[j] = xi0; | ||
| 202 | dst.qs[j] |= xi1 << 4; | ||
| 203 | } | ||
| 204 | } | ||
| 205 | |||
| 206 | void quantize_q5_0(device const float * src, device block_q5_0 & dst) { | ||
| 207 | #pragma METAL fp math_mode(safe) | ||
| 208 | float amax = 0.0f; // absolute max | ||
| 209 | float max = 0.0f; | ||
| 210 | |||
| 211 | for (int j = 0; j < QK5_0; j++) { | ||
| 212 | const float v = src[j]; | ||
| 213 | if (amax < fabs(v)) { | ||
| 214 | amax = fabs(v); | ||
| 215 | max = v; | ||
| 216 | } | ||
| 217 | } | ||
| 218 | |||
| 219 | const float d = max / -16; | ||
| 220 | const float id = d ? 1.0f/d : 0.0f; | ||
| 221 | |||
| 222 | dst.d = d; | ||
| 223 | |||
| 224 | uint32_t qh = 0; | ||
| 225 | for (int j = 0; j < QK5_0/2; ++j) { | ||
| 226 | const float x0 = src[0 + j]*id; | ||
| 227 | const float x1 = src[QK5_0/2 + j]*id; | ||
| 228 | |||
| 229 | const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); | ||
| 230 | const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); | ||
| 231 | |||
| 232 | dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); | ||
| 233 | qh |= ((xi0 & 0x10u) >> 4) << (j + 0); | ||
| 234 | qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); | ||
| 235 | } | ||
| 236 | |||
| 237 | thread const uint8_t * qh8 = (thread const uint8_t *)&qh; | ||
| 238 | |||
| 239 | for (int j = 0; j < 4; ++j) { | ||
| 240 | dst.qh[j] = qh8[j]; | ||
| 241 | } | ||
| 242 | } | ||
| 243 | |||
| 244 | void quantize_q5_1(device const float * src, device block_q5_1 & dst) { | ||
| 245 | #pragma METAL fp math_mode(safe) | ||
| 246 | float max = src[0]; | ||
| 247 | float min = src[0]; | ||
| 248 | |||
| 249 | for (int j = 1; j < QK5_1; j++) { | ||
| 250 | const float v = src[j]; | ||
| 251 | min = v < min ? v : min; | ||
| 252 | max = v > max ? v : max; | ||
| 253 | } | ||
| 254 | |||
| 255 | const float d = (max - min) / 31; | ||
| 256 | const float id = d ? 1.0f/d : 0.0f; | ||
| 257 | |||
| 258 | dst.d = d; | ||
| 259 | dst.m = min; | ||
| 260 | |||
| 261 | uint32_t qh = 0; | ||
| 262 | for (int j = 0; j < QK5_1/2; ++j) { | ||
| 263 | const float x0 = (src[0 + j] - min)*id; | ||
| 264 | const float x1 = (src[QK5_1/2 + j] - min)*id; | ||
| 265 | |||
| 266 | const uint8_t xi0 = (uint8_t)(x0 + 0.5f); | ||
| 267 | const uint8_t xi1 = (uint8_t)(x1 + 0.5f); | ||
| 268 | |||
| 269 | dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); | ||
| 270 | qh |= ((xi0 & 0x10u) >> 4) << (j + 0); | ||
| 271 | qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); | ||
| 272 | } | ||
| 273 | |||
| 274 | thread const uint8_t * qh8 = (thread const uint8_t *)&qh; | ||
| 275 | |||
| 276 | for (int j = 0; j < 4; ++j) { | ||
| 277 | dst.qh[j] = qh8[j]; | ||
| 278 | } | ||
| 279 | } | ||
| 280 | |||
| 281 | void quantize_q8_0(device const float * src, device block_q8_0 & dst) { | ||
| 282 | #pragma METAL fp math_mode(safe) | ||
| 283 | float amax = 0.0f; // absolute max | ||
| 284 | |||
| 285 | for (int j = 0; j < QK8_0; j++) { | ||
| 286 | const float v = src[j]; | ||
| 287 | amax = MAX(amax, fabs(v)); | ||
| 288 | } | ||
| 289 | |||
| 290 | const float d = amax / ((1 << 7) - 1); | ||
| 291 | const float id = d ? 1.0f/d : 0.0f; | ||
| 292 | |||
| 293 | dst.d = d; | ||
| 294 | |||
| 295 | for (int j = 0; j < QK8_0; ++j) { | ||
| 296 | const float x0 = src[j]*id; | ||
| 297 | |||
| 298 | dst.qs[j] = round(x0); | ||
| 299 | } | ||
| 300 | } | ||
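`quantize_q8_0` is plain symmetric 8-bit quantization: the block scale is the absolute maximum divided by 127, each value is divided by that scale and rounded, and dequantization is the inverse multiply. As equations (restating the loop above):

$$
d = \frac{\max_j |x_j|}{127}, \qquad q_j = \operatorname{round}\!\left(\frac{x_j}{d}\right) \in [-127, 127], \qquad \hat{x}_j = d\,q_j .
$$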
| 301 | |||
| 302 | void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) { | ||
| 303 | #pragma METAL fp math_mode(safe) | ||
| 304 | float amax = 0.0f; // absolute max | ||
| 305 | float max = 0.0f; | ||
| 306 | |||
| 307 | for (int j = 0; j < QK4_NL; j++) { | ||
| 308 | const float v = src[j]; | ||
| 309 | if (amax < fabs(v)) { | ||
| 310 | amax = fabs(v); | ||
| 311 | max = v; | ||
| 312 | } | ||
| 313 | } | ||
| 314 | |||
| 315 | const float d = max / kvalues_iq4nl_f[0]; | ||
| 316 | const float id = d ? 1.0f/d : 0.0f; | ||
| 317 | |||
| 318 | float sumqx = 0, sumq2 = 0; | ||
| 319 | for (int j = 0; j < QK4_NL/2; ++j) { | ||
| 320 | const float x0 = src[0 + j]*id; | ||
| 321 | const float x1 = src[QK4_NL/2 + j]*id; | ||
| 322 | |||
| 323 | const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0); | ||
| 324 | const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1); | ||
| 325 | |||
| 326 | dst.qs[j] = xi0 | (xi1 << 4); | ||
| 327 | |||
| 328 | const float v0 = kvalues_iq4nl_f[xi0]; | ||
| 329 | const float v1 = kvalues_iq4nl_f[xi1]; | ||
| 330 | const float w0 = src[0 + j]*src[0 + j]; | ||
| 331 | const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j]; | ||
| 332 | sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j]; | ||
| 333 | sumq2 += w0*v0*v0 + w1*v1*v1; | ||
| 334 | |||
| 335 | } | ||
| 336 | |||
| 337 | dst.d = sumq2 > 0 ? sumqx/sumq2 : d; | ||
| 338 | } | ||
| 339 | |||
| 340 | template <typename type4x4> | ||
| 341 | void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) { | ||
| 342 | device const uint16_t * qs = ((device const uint16_t *)xb + 2); | ||
| 343 | const float d1 = il ? (xb->d / 16.h) : xb->d; | ||
| 344 | const float d2 = d1 / 256.f; | ||
| 345 | const float m = xb->m; | ||
| 346 | const ushort mask0 = il ? 0x00F0 : 0x000F; | ||
| 347 | const ushort mask1 = mask0 << 8; | ||
| 348 | |||
| 349 | float4x4 reg_f; | ||
| 350 | |||
| 351 | for (int i = 0; i < 8; i++) { | ||
| 352 | reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m; | ||
| 353 | reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m; | ||
| 354 | } | ||
| 355 | |||
| 356 | reg = (type4x4) reg_f; | ||
| 357 | } | ||
| 358 | |||
| 359 | template <typename type4> | ||
| 360 | void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) { | ||
| 361 | device const uint16_t * qs = ((device const uint16_t *)xb + 2); | ||
| 362 | const float d1 = (il/4) ? (xb->d / 16.h) : xb->d; | ||
| 363 | const float d2 = d1 / 256.f; | ||
| 364 | const float m = xb->m; | ||
| 365 | const ushort mask0 = (il/4) ? 0x00F0 : 0x000F; | ||
| 366 | const ushort mask1 = mask0 << 8; | ||
| 367 | |||
| 368 | for (int i = 0; i < 2; i++) { | ||
| 369 | reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m; | ||
| 370 | reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m; | ||
| 371 | } | ||
| 372 | } | ||
| 373 | |||
| 374 | template <typename type4x4> | ||
| 375 | void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) { | ||
| 376 | device const uint16_t * qs = ((device const uint16_t *)xb + 3); | ||
| 377 | const float d = xb->d; | ||
| 378 | const float md = -16.h * xb->d; | ||
| 379 | const ushort mask = il ? 0x00F0 : 0x000F; | ||
| 380 | |||
| 381 | const uint32_t qh = *((device const uint32_t *)xb->qh); | ||
| 382 | |||
| 383 | const int x_mv = il ? 4 : 0; | ||
| 384 | |||
| 385 | const int gh_mv = il ? 12 : 0; | ||
| 386 | const int gh_bk = il ? 0 : 4; | ||
| 387 | |||
| 388 | float4x4 reg_f; | ||
| 389 | |||
| 390 | for (int i = 0; i < 8; i++) { | ||
| 391 | // extract the 5-th bits for x0 and x1 | ||
| 392 | const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; | ||
| 393 | const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; | ||
| 394 | |||
| 395 | // combine the 4-bits from qs with the 5th bit | ||
| 396 | const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); | ||
| 397 | const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); | ||
| 398 | |||
| 399 | reg_f[i/2][2*(i%2) + 0] = d * x0 + md; | ||
| 400 | reg_f[i/2][2*(i%2) + 1] = d * x1 + md; | ||
| 401 | } | ||
| 402 | |||
| 403 | reg = (type4x4) reg_f; | ||
| 404 | } | ||
| 405 | |||
| 406 | template <typename type4> | ||
| 407 | void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) { | ||
| 408 | device const uint16_t * qs = ((device const uint16_t *)xb + 3); | ||
| 409 | const float d = xb->d; | ||
| 410 | const float md = -16.h * xb->d; | ||
| 411 | const ushort mask = (il/4) ? 0x00F0 : 0x000F; | ||
| 412 | |||
| 413 | const uint32_t qh = *((device const uint32_t *)xb->qh); | ||
| 414 | |||
| 415 | const int x_mv = (il/4) ? 4 : 0; | ||
| 416 | |||
| 417 | const int gh_mv = (il/4) ? 12 : 0; | ||
| 418 | const int gh_bk = (il/4) ? 0 : 4; | ||
| 419 | |||
| 420 | for (int ii = 0; ii < 2; ii++) { | ||
| 421 | int i = 2*(il%4) + ii; | ||
| 422 | |||
| 423 | // extract the 5-th bits for x0 and x1 | ||
| 424 | const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; | ||
| 425 | const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; | ||
| 426 | |||
| 427 | // combine the 4-bits from qs with the 5th bit | ||
| 428 | const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); | ||
| 429 | const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); | ||
| 430 | |||
| 431 | reg[2*ii + 0] = d * x0 + md; | ||
| 432 | reg[2*ii + 1] = d * x1 + md; | ||
| 433 | } | ||
| 434 | } | ||
| 435 | |||
| 436 | template <typename type4x4> | ||
| 437 | void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) { | ||
| 438 | device const uint16_t * qs = ((device const uint16_t *)xb + 4); | ||
| 439 | const float d = xb->d; | ||
| 440 | const float m = xb->m; | ||
| 441 | const ushort mask = il ? 0x00F0 : 0x000F; | ||
| 442 | |||
| 443 | const uint32_t qh = *((device const uint32_t *)xb->qh); | ||
| 444 | |||
| 445 | const int x_mv = il ? 4 : 0; | ||
| 446 | |||
| 447 | const int gh_mv = il ? 12 : 0; | ||
| 448 | const int gh_bk = il ? 0 : 4; | ||
| 449 | |||
| 450 | float4x4 reg_f; | ||
| 451 | |||
| 452 | for (int i = 0; i < 8; i++) { | ||
| 453 | // extract the 5-th bits for x0 and x1 | ||
| 454 | const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; | ||
| 455 | const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; | ||
| 456 | |||
| 457 | // combine the 4-bits from qs with the 5th bit | ||
| 458 | const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); | ||
| 459 | const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); | ||
| 460 | |||
| 461 | reg_f[i/2][2*(i%2) + 0] = d * x0 + m; | ||
| 462 | reg_f[i/2][2*(i%2) + 1] = d * x1 + m; | ||
| 463 | } | ||
| 464 | |||
| 465 | reg = (type4x4) reg_f; | ||
| 466 | } | ||
| 467 | |||
| 468 | template <typename type4> | ||
| 469 | void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) { | ||
| 470 | device const uint16_t * qs = ((device const uint16_t *)xb + 4); | ||
| 471 | const float d = xb->d; | ||
| 472 | const float m = xb->m; | ||
| 473 | const ushort mask = (il/4) ? 0x00F0 : 0x000F; | ||
| 474 | |||
| 475 | const uint32_t qh = *((device const uint32_t *)xb->qh); | ||
| 476 | |||
| 477 | const int x_mv = (il/4) ? 4 : 0; | ||
| 478 | |||
| 479 | const int gh_mv = (il/4) ? 12 : 0; | ||
| 480 | const int gh_bk = (il/4) ? 0 : 4; | ||
| 481 | |||
| 482 | for (int ii = 0; ii < 2; ii++) { | ||
| 483 | int i = 2*(il%4) + ii; | ||
| 484 | |||
| 485 | // extract the 5-th bits for x0 and x1 | ||
| 486 | const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; | ||
| 487 | const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; | ||
| 488 | |||
| 489 | // combine the 4-bits from qs with the 5th bit | ||
| 490 | const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); | ||
| 491 | const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); | ||
| 492 | |||
| 493 | reg[2*ii + 0] = d * x0 + m; | ||
| 494 | reg[2*ii + 1] = d * x1 + m; | ||
| 495 | } | ||
| 496 | } | ||
| 497 | |||
| 498 | template <typename type4x4> | ||
| 499 | void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { | ||
| 500 | device const int8_t * qs = ((device const int8_t *)xb->qs); | ||
| 501 | const float d = xb->d; | ||
| 502 | |||
| 503 | float4x4 reg_f; | ||
| 504 | |||
| 505 | for (int i = 0; i < 16; i++) { | ||
| 506 | reg_f[i/4][i%4] = (qs[i + 16*il] * d); | ||
| 507 | } | ||
| 508 | |||
| 509 | reg = (type4x4) reg_f; | ||
| 510 | } | ||
| 511 | |||
| 512 | template <typename type4> | ||
| 513 | void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) { | ||
| 514 | device const int8_t * qs = ((device const int8_t *)xb->qs); | ||
| 515 | const float d = xb->d; | ||
| 516 | |||
| 517 | for (int i = 0; i < 4; i++) { | ||
| 518 | reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d); | ||
| 519 | } | ||
| 520 | } | ||
| 521 | |||
| 522 | template <typename type4x4> | ||
| 523 | void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) { | ||
| 524 | device const uint8_t * q2 = (device const uint8_t *)xb->qs; | ||
| 525 | |||
| 526 | const float d = e8m0_to_fp32(xb->e); | ||
| 527 | const uint8_t shr = il >= 1 ? 4 : 0; | ||
| 528 | |||
| 529 | for (int i = 0; i < 4; ++i) { | ||
| 530 | reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F]; | ||
| 531 | reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F]; | ||
| 532 | reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F]; | ||
| 533 | reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F]; | ||
| 534 | } | ||
| 535 | } | ||
| 536 | |||
| 537 | template <typename type4> | ||
| 538 | void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) { | ||
| 539 | device const uint8_t * q2 = (device const uint8_t *)xb->qs; | ||
| 540 | |||
| 541 | const float d = e8m0_to_fp32(xb->e); | ||
| 542 | const short il4 = il%4; | ||
| 543 | |||
| 544 | const uint8_t shr = il >= 4 ? 4 : 0; | ||
| 545 | |||
| 546 | reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F]; | ||
| 547 | reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F]; | ||
| 548 | reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F]; | ||
| 549 | reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F]; | ||
| 550 | } | ||
| 551 | |||
| 552 | template <typename type4x4> | ||
| 553 | void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { | ||
| 554 | const float d = xb->d; | ||
| 555 | const float min = xb->dmin; | ||
| 556 | device const uint8_t * q = (device const uint8_t *)xb->qs; | ||
| 557 | float dl, ml; | ||
| 558 | uint8_t sc = xb->scales[il]; | ||
| 559 | |||
| 560 | q = q + 32*(il/8) + 16*(il&1); | ||
| 561 | il = (il/2)%4; | ||
| 562 | |||
| 563 | half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); | ||
| 564 | uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); | ||
| 565 | dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4); | ||
| 566 | for (int i = 0; i < 16; ++i) { | ||
| 567 | reg[i/4][i%4] = dl * (q[i] & mask) - ml; | ||
| 568 | } | ||
| 569 | } | ||
| 570 | |||
| 571 | template <typename type4x4> | ||
| 572 | void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { | ||
| 573 | const half d_all = xb->d; | ||
| 574 | device const uint8_t * q = (device const uint8_t *)xb->qs; | ||
| 575 | device const uint8_t * h = (device const uint8_t *)xb->hmask; | ||
| 576 | device const int8_t * scales = (device const int8_t *)xb->scales; | ||
| 577 | |||
| 578 | q = q + 32 * (il/8) + 16 * (il&1); | ||
| 579 | h = h + 16 * (il&1); | ||
| 580 | uint8_t m = 1 << (il/2); | ||
| 581 | uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \ | ||
| 582 | ((il/4)>0 ? 12 : 3); | ||
| 583 | uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; | ||
| 584 | uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; | ||
| 585 | int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) | ||
| 586 | : (scale_2&kmask2) | ((scale_1&kmask1) << 4); | ||
| 587 | float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); | ||
| 588 | const float ml = 4.f * dl; | ||
| 589 | |||
| 590 | il = (il/2) & 3; | ||
| 591 | const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); | ||
| 592 | const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); | ||
| 593 | dl *= coef; | ||
| 594 | |||
| 595 | for (int i = 0; i < 16; ++i) { | ||
| 596 | reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); | ||
| 597 | } | ||
| 598 | } | ||
| 599 | |||
| 600 | static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { | ||
| 601 | return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} | ||
| 602 | : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; | ||
| 603 | } | ||
| 604 | |||
| 605 | template <typename type4x4> | ||
| 606 | void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) { | ||
| 607 | device const uchar * q = xb->qs; | ||
| 608 | |||
| 609 | short is = (il/4) * 2; | ||
| 610 | q = q + (il/4) * 32 + 16 * (il&1); | ||
| 611 | il = il & 3; | ||
| 612 | const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); | ||
| 613 | const float d = il < 2 ? xb->d : xb->d / 16.h; | ||
| 614 | const float min = xb->dmin; | ||
| 615 | const float dl = d * sc[0]; | ||
| 616 | const float ml = min * sc[1]; | ||
| 617 | |||
| 618 | const ushort mask = il < 2 ? 0x0F : 0xF0; | ||
| 619 | for (int i = 0; i < 16; ++i) { | ||
| 620 | reg[i/4][i%4] = dl * (q[i] & mask) - ml; | ||
| 621 | } | ||
| 622 | } | ||
| 623 | |||
| 624 | template <typename type4x4> | ||
| 625 | void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) { | ||
| 626 | device const uint8_t * q = xb->qs; | ||
| 627 | device const uint8_t * qh = xb->qh; | ||
| 628 | |||
| 629 | short is = (il/4) * 2; | ||
| 630 | q = q + 32 * (il/4) + 16 * (il&1); | ||
| 631 | qh = qh + 16 * (il&1); | ||
| 632 | uint8_t ul = 1 << (il/2); | ||
| 633 | il = il & 3; | ||
| 634 | const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); | ||
| 635 | const float d = il < 2 ? xb->d : xb->d / 16.f; | ||
| 636 | const float min = xb->dmin; | ||
| 637 | const float dl = d * sc[0]; | ||
| 638 | const float ml = min * sc[1]; | ||
| 639 | |||
| 640 | const ushort mask = il<2 ? 0x0F : 0xF0; | ||
| 641 | const float qh_val = il<2 ? 16.f : 256.f; | ||
| 642 | for (int i = 0; i < 16; ++i) { | ||
| 643 | reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; | ||
| 644 | } | ||
| 645 | } | ||
| 646 | |||
| 647 | template <typename type4x4> | ||
| 648 | void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { | ||
| 649 | const half d_all = xb->d; | ||
| 650 | device const uint16_t * ql = (device const uint16_t *)xb->ql; | ||
| 651 | device const uint16_t * qh = (device const uint16_t *)xb->qh; | ||
| 652 | device const int8_t * scales = (device const int8_t *)xb->scales; | ||
| 653 | |||
| 654 | ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1); | ||
| 655 | qh = qh + 16*(il/8) + 8*(il&1); | ||
| 656 | float sc = scales[(il%2) + 2 * ((il/2))]; | ||
| 657 | il = (il/2) & 3; | ||
| 658 | |||
| 659 | const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303); | ||
| 660 | const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F; | ||
| 661 | const float ml = d_all * sc * 32.f; | ||
| 662 | const float dl0 = d_all * sc; | ||
| 663 | const float dl1 = dl0 / 256.f; | ||
| 664 | const float dl2 = dl0 / (256.f * 256.f); | ||
| 665 | const float dl3 = dl0 / (256.f * 256.f * 256.f); | ||
| 666 | const uint8_t shr_h = il>2 ? 2 : 0; | ||
| 667 | const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4); | ||
| 668 | const uint8_t shr_l = il>1 ? 4 : 0; | ||
| 669 | for (int i = 0; i < 4; ++i) { | ||
| 670 | const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2; | ||
| 671 | const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1; | ||
| 672 | const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l); | ||
| 673 | reg[i][0] = dl0 * ((half)(q & 0xFF)) - ml; | ||
| 674 | reg[i][1] = dl1 * ((float)(q & 0xFF00)) - ml; | ||
| 675 | reg[i][2] = dl2 * ((float)(q & 0xFF0000)) - ml; | ||
| 676 | reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml; | ||
| 677 | } | ||
| 678 | } | ||
| 679 | |||
| 680 | template <typename type4x4> | ||
| 681 | void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) { | ||
| 682 | // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 | ||
| 683 | const float d = xb->d; | ||
| 684 | const int ib32 = il/2; | ||
| 685 | il = il%2; | ||
| 686 | // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 | ||
| 687 | // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's. | ||
| 688 | device const uint16_t * q2 = xb->qs + 4*ib32; | ||
| 689 | const uint32_t aux32_g = q2[0] | (q2[1] << 16); | ||
| 690 | const uint32_t aux32_s = q2[2] | (q2[3] << 16); | ||
| 691 | thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g; | ||
| 692 | const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f; | ||
| 693 | constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]); | ||
| 694 | uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127]; | ||
| 695 | for (int i = 0; i < 8; ++i) { | ||
| 696 | reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); | ||
| 697 | } | ||
| 698 | grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]); | ||
| 699 | signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127]; | ||
| 700 | for (int i = 0; i < 8; ++i) { | ||
| 701 | reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); | ||
| 702 | } | ||
| 703 | } | ||
| 704 | |||
| 705 | template <typename type4x4> | ||
| 706 | void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) { | ||
| 707 | // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 | ||
| 708 | const float d = xb->d; | ||
| 709 | const int ib32 = il/2; | ||
| 710 | il = il%2; | ||
| 711 | // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 | ||
| 712 | device const uint16_t * q2 = xb->qs + 4*ib32; | ||
| 713 | const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; | ||
| 714 | constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511)); | ||
| 715 | uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9]; | ||
| 716 | for (int i = 0; i < 8; ++i) { | ||
| 717 | reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); | ||
| 718 | } | ||
| 719 | grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511)); | ||
| 720 | signs = ksigns_iq2xs[q2[2*il+1] >> 9]; | ||
| 721 | for (int i = 0; i < 8; ++i) { | ||
| 722 | reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f); | ||
| 723 | } | ||
| 724 | } | ||
| 725 | |||
| 726 | template <typename type4x4> | ||
| 727 | void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) { | ||
| 728 | // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 | ||
| 729 | const float d = xb->d; | ||
| 730 | const int ib32 = il/2; | ||
| 731 | il = il%2; | ||
| 732 | // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 | ||
| 733 | device const uint8_t * q3 = xb->qs + 8*ib32; | ||
| 734 | device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32; | ||
| 735 | const uint32_t aux32 = gas[0] | (gas[1] << 16); | ||
| 736 | const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f; | ||
| 737 | constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]); | ||
| 738 | constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]); | ||
| 739 | uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127]; | ||
| 740 | for (int i = 0; i < 4; ++i) { | ||
| 741 | reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); | ||
| 742 | reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); | ||
| 743 | } | ||
| 744 | grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]); | ||
| 745 | grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]); | ||
| 746 | signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127]; | ||
| 747 | for (int i = 0; i < 4; ++i) { | ||
| 748 | reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f); | ||
| 749 | reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f); | ||
| 750 | } | ||
| 751 | } | ||
| 752 | |||
| 753 | template <typename type4x4> | ||
| 754 | void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) { | ||
| 755 | // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 | ||
| 756 | const float d = xb->d; | ||
| 757 | const int ib32 = il/2; | ||
| 758 | il = il%2; | ||
| 759 | // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 | ||
| 760 | device const uint8_t * qs = xb->qs + 8*ib32; | ||
| 761 | device const uint8_t * signs = xb->signs + 4*ib32 + 2*il; | ||
| 762 | const uint8_t qh = xb->qh[ib32] >> 4*il; | ||
| 763 | const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)); | ||
| 764 | constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256))); | ||
| 765 | constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256))); | ||
| 766 | for (int i = 0; i < 4; ++i) { | ||
| 767 | reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]); | ||
| 768 | reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]); | ||
| 769 | } | ||
| 770 | grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256))); | ||
| 771 | grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256))); | ||
| 772 | for (int i = 0; i < 4; ++i) { | ||
| 773 | reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]); | ||
| 774 | reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]); | ||
| 775 | } | ||
| 776 | } | ||
| 777 | |||
| 778 | template <typename type4x4> | ||
| 779 | void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) { | ||
| 780 | // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 | ||
| 781 | const float d = xb->d; | ||
| 782 | const int ib32 = il/2; | ||
| 783 | il = il%2; | ||
| 784 | // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 | ||
| 785 | device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; | ||
| 786 | device const uint8_t * signs = qs + QK_K/8; | ||
| 787 | const uint8_t qh = xb->qh[ib32] >> 4*il; | ||
| 788 | const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f; | ||
| 789 | constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300))); | ||
| 790 | constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300))); | ||
| 791 | for (int i = 0; i < 8; ++i) { | ||
| 792 | reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]); | ||
| 793 | reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]); | ||
| 794 | } | ||
| 795 | } | ||
| 796 | |||
| 797 | template <typename type4x4> | ||
| 798 | void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { | ||
| 799 | // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 | ||
| 800 | const int ib32 = il/2; | ||
| 801 | il = il%2; | ||
| 802 | const float d = xb->d; | ||
| 803 | device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; | ||
| 804 | device const uint16_t * qh = xb->qh; | ||
| 805 | const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1); | ||
| 806 | const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA); | ||
| 807 | const uint16_t h = qh[ib32] >> 6*il; | ||
| 808 | constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700))); | ||
| 809 | constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700))); | ||
| 810 | for (int i = 0; i < 4; ++i) { | ||
| 811 | reg[0][i] = dl * (grid1[i] & 0xf) + ml; | ||
| 812 | reg[1][i] = dl * (grid1[i] >> 4) + ml; | ||
| 813 | reg[2][i] = dl * (grid2[i] & 0xf) + ml; | ||
| 814 | reg[3][i] = dl * (grid2[i] >> 4) + ml; | ||
| 815 | } | ||
| 816 | } | ||
| 817 | |||
| 818 | template <typename type4x4> | ||
| 819 | void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) { | ||
| 820 | // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 | ||
| 821 | const int ib32 = il/2; | ||
| 822 | il = il%2; | ||
| 823 | device const uint16_t * sc = (device const uint16_t *)xb->scales; | ||
| 824 | |||
| 825 | iq1m_scale_t scale; | ||
| 826 | scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); | ||
| 827 | const float d = scale.f16; | ||
| 828 | |||
| 829 | device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; | ||
| 830 | device const uint8_t * qh = xb->qh + 2*ib32 + il; | ||
| 831 | |||
| 832 | const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1); | ||
| 833 | const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); | ||
| 834 | const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); | ||
| 835 | constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); | ||
| 836 | constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); | ||
| 837 | for (int i = 0; i < 4; ++i) { | ||
| 838 | reg[0][i] = dl * (grid1[i] & 0xf) + ml1; | ||
| 839 | reg[1][i] = dl * (grid1[i] >> 4) + ml1; | ||
| 840 | reg[2][i] = dl * (grid2[i] & 0xf) + ml2; | ||
| 841 | reg[3][i] = dl * (grid2[i] >> 4) + ml2; | ||
| 842 | } | ||
| 843 | } | ||
| 844 | |||
| 845 | template <typename type4x4> | ||
| 846 | void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) { | ||
| 847 | device const uint16_t * q4 = (device const uint16_t *)xb->qs; | ||
| 848 | const float d = xb->d; | ||
| 849 | uint32_t aux32; | ||
| 850 | thread const uint8_t * q8 = (thread const uint8_t *)&aux32; | ||
| 851 | for (int i = 0; i < 4; ++i) { | ||
| 852 | aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f; | ||
| 853 | reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; | ||
| 854 | reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; | ||
| 855 | reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; | ||
| 856 | reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; | ||
| 857 | } | ||
| 858 | } | ||
| 859 | |||
| 860 | template <typename type4> | ||
| 861 | void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) { | ||
| 862 | device const uint16_t * q4 = (device const uint16_t *)xb->qs; | ||
| 863 | const float d = xb->d; | ||
| 864 | uint32_t aux32; | ||
| 865 | thread const uint8_t * q8 = (thread const uint8_t *)&aux32; | ||
| 866 | aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f; | ||
| 867 | reg[0] = d * kvalues_iq4nl_f[q8[0]]; | ||
| 868 | reg[1] = d * kvalues_iq4nl_f[q8[1]]; | ||
| 869 | reg[2] = d * kvalues_iq4nl_f[q8[2]]; | ||
| 870 | reg[3] = d * kvalues_iq4nl_f[q8[3]]; | ||
| 871 | } | ||
| 872 | |||
| 873 | template <typename type4x4> | ||
| 874 | void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) { | ||
| 875 | // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 | ||
| 876 | const int ib32 = il/2; | ||
| 877 | il = il%2; | ||
| 878 | // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16 | ||
| 879 | device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32; | ||
| 880 | const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4); | ||
| 881 | const float d = (float)xb->d * (ls - 32); | ||
| 882 | uint32_t aux32; | ||
| 883 | thread const uint8_t * q8 = (thread const uint8_t *)&aux32; | ||
| 884 | for (int i = 0; i < 4; ++i) { | ||
| 885 | aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f; | ||
| 886 | reg[i][0] = d * kvalues_iq4nl_f[q8[0]]; | ||
| 887 | reg[i][1] = d * kvalues_iq4nl_f[q8[1]]; | ||
| 888 | reg[i][2] = d * kvalues_iq4nl_f[q8[2]]; | ||
| 889 | reg[i][3] = d * kvalues_iq4nl_f[q8[3]]; | ||
| 890 | } | ||
| 891 | } | ||
| 892 | |||
| 893 | enum ggml_sort_order { | ||
| 894 | GGML_SORT_ORDER_ASC, | ||
| 895 | GGML_SORT_ORDER_DESC, | ||
| 896 | }; | ||
| 897 | |||
| 898 | constant float GELU_COEF_A = 0.044715f; | ||
| 899 | constant float GELU_QUICK_COEF = -1.702f; | ||
| 900 | constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; | ||
| 901 | constant float SQRT_2_INV = 0.70710678118654752440084436210484f; | ||
| 902 | |||
| 903 | // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation | ||
| 904 | // ref: https://www.johndcook.com/blog/python_erf/ | ||
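| | // i.e. erf(x) ~= 1 - (a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5)*exp(-x*x), with t = 1/(1 + p*x), for x >= 0 | ||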
| 905 | constant float p_erf = 0.3275911f; | ||
| 906 | constant float a1_erf = 0.254829592f; | ||
| 907 | constant float a2_erf = -0.284496736f; | ||
| 908 | constant float a3_erf = 1.421413741f; | ||
| 909 | constant float a4_erf = -1.453152027f; | ||
| 910 | constant float a5_erf = 1.061405429f; | ||
| 911 | |||
| 912 | template<typename T> | ||
| 913 | inline T erf_approx(T x) { | ||
| 914 | T sign_x = sign(x); | ||
| 915 | x = fabs(x); | ||
| 916 | T t = 1.0f / (1.0f + p_erf * x); | ||
| 917 | T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); | ||
| 918 | return sign_x * y; | ||
| 919 | } | ||
| 920 | |||
| 921 | template<typename T> T elu_approx(T x); | ||
| 922 | |||
| 923 | template<> inline float elu_approx<float>(float x) { | ||
| 924 | return (x > 0.f) ? x : (exp(x) - 1); | ||
| 925 | } | ||
| 926 | |||
| 927 | template<> inline float4 elu_approx<float4>(float4 x) { | ||
| 928 | float4 res; | ||
| 929 | |||
| 930 | res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f); | ||
| 931 | res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f); | ||
| 932 | res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f); | ||
| 933 | res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f); | ||
| 934 | |||
| 935 | return res; | ||
| 936 | } | ||
| 937 | |||
| 938 | constant short FC_unary_op [[function_constant(FC_UNARY + 0)]]; | ||
| 939 | constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]]; | ||
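| | // FC_unary_op selects which unary op is baked into the specialized pipeline; | ||
| | // FC_unary_cnt enables the contiguous fast path that indexes the flattened src0/dst directly by tgpig.x | ||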
| 940 | |||
| 941 | template <typename T0, typename T, typename TC> | ||
| 942 | kernel void kernel_unary_impl( | ||
| 943 | constant ggml_metal_kargs_unary & args, | ||
| 944 | device const char * src0, | ||
| 945 | device char * dst, | ||
| 946 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 947 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 948 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 949 | #define FC_OP FC_unary_op | ||
| 950 | #define FC_CNT FC_unary_cnt | ||
| 951 | |||
| 952 | device const T0 * src0_ptr; | ||
| 953 | device T * dst_ptr; | ||
| 954 | |||
| 955 | int i0; | ||
| 956 | |||
| 957 | if (FC_CNT) { | ||
| 958 | i0 = tgpig.x; | ||
| 959 | |||
| 960 | src0_ptr = (device const T0 *) (src0); | ||
| 961 | dst_ptr = (device T *) (dst); | ||
| 962 | } else { | ||
| 963 | const int i03 = tgpig.z; | ||
| 964 | const int i02 = tgpig.y; | ||
| 965 | const int k0 = tgpig.x/args.ne01; | ||
| 966 | const int i01 = tgpig.x - k0*args.ne01; | ||
| 967 | |||
| 968 | i0 = k0*ntg.x + tpitg.x; | ||
| 969 | |||
| 970 | src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); | ||
| 971 | dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 ); | ||
| 972 | } | ||
| 973 | |||
| 974 | { | ||
| 975 | //threadgroup_barrier(mem_flags::mem_none); | ||
| 976 | |||
| 977 | if (!FC_CNT) { | ||
| 978 | if (i0 >= args.ne0) { | ||
| 979 | return; | ||
| 980 | } | ||
| 981 | } | ||
| 982 | |||
| 983 | const TC x = (TC) src0_ptr[i0]; | ||
| 984 | |||
| 985 | if (FC_OP == OP_UNARY_NUM_SCALE) { | ||
| 986 | dst_ptr[i0] = (T) (args.scale * x + args.bias); | ||
| 987 | } | ||
| 988 | |||
| 989 | if (FC_OP == OP_UNARY_NUM_FILL) { | ||
| 990 | dst_ptr[i0] = (T) args.val; | ||
| 991 | } | ||
| 992 | |||
| 993 | if (FC_OP == OP_UNARY_NUM_CLAMP) { | ||
| 994 | dst_ptr[i0] = (T) clamp(x, args.min, args.max); | ||
| 995 | } | ||
| 996 | |||
| 997 | if (FC_OP == OP_UNARY_NUM_SQR) { | ||
| 998 | dst_ptr[i0] = (T) (x * x); | ||
| 999 | } | ||
| 1000 | |||
| 1001 | if (FC_OP == OP_UNARY_NUM_SQRT) { | ||
| 1002 | dst_ptr[i0] = (T) sqrt(x); | ||
| 1003 | } | ||
| 1004 | |||
| 1005 | if (FC_OP == OP_UNARY_NUM_SIN) { | ||
| 1006 | dst_ptr[i0] = (T) sin(x); | ||
| 1007 | } | ||
| 1008 | |||
| 1009 | if (FC_OP == OP_UNARY_NUM_COS) { | ||
| 1010 | dst_ptr[i0] = (T) cos(x); | ||
| 1011 | } | ||
| 1012 | |||
| 1013 | if (FC_OP == OP_UNARY_NUM_LOG) { | ||
| 1014 | dst_ptr[i0] = (T) log(x); | ||
| 1015 | } | ||
| 1016 | |||
| 1017 | if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) { | ||
| 1018 | dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope)); | ||
| 1019 | } | ||
| 1020 | |||
| 1021 | if (FC_OP == OP_UNARY_NUM_TANH) { | ||
| 1022 | dst_ptr[i0] = (T) precise::tanh(x); | ||
| 1023 | } | ||
| 1024 | |||
| 1025 | if (FC_OP == OP_UNARY_NUM_RELU) { | ||
| 1026 | dst_ptr[i0] = (T) fmax(0, x); | ||
| 1027 | } | ||
| 1028 | |||
| 1029 | if (FC_OP == OP_UNARY_NUM_SIGMOID) { | ||
| 1030 | dst_ptr[i0] = (T) (1 / (1 + exp(-x))); | ||
| 1031 | } | ||
| 1032 | |||
| 1033 | if (FC_OP == OP_UNARY_NUM_GELU) { | ||
| 1034 | dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x)))); | ||
| 1035 | } | ||
| 1036 | |||
| 1037 | if (FC_OP == OP_UNARY_NUM_GELU_ERF) { | ||
| 1038 | dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x))); | ||
| 1039 | } | ||
| 1040 | |||
| 1041 | if (FC_OP == OP_UNARY_NUM_GELU_QUICK) { | ||
| 1042 | dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x)))); | ||
| 1043 | } | ||
| 1044 | |||
| 1045 | if (FC_OP == OP_UNARY_NUM_SILU) { | ||
| 1046 | dst_ptr[i0] = (T) (x / (1 + exp(-x))); | ||
| 1047 | } | ||
| 1048 | |||
| 1049 | if (FC_OP == OP_UNARY_NUM_ELU) { | ||
| 1050 | dst_ptr[i0] = (T) elu_approx(x); | ||
| 1051 | } | ||
| 1052 | |||
| 1053 | if (FC_OP == OP_UNARY_NUM_NEG) { | ||
| 1054 | dst_ptr[i0] = (T) -x; | ||
| 1055 | } | ||
| 1056 | |||
| 1057 | if (FC_OP == OP_UNARY_NUM_ABS) { | ||
| 1058 | dst_ptr[i0] = (T) fabs(x); | ||
| 1059 | } | ||
| 1060 | |||
| 1061 | if (FC_OP == OP_UNARY_NUM_SGN) { | ||
| 1062 | dst_ptr[i0] = T(x > 0) - T(x < 0); | ||
| 1063 | } | ||
| 1064 | |||
| 1065 | if (FC_OP == OP_UNARY_NUM_STEP) { | ||
| 1066 | dst_ptr[i0] = T(x > 0); | ||
| 1067 | } | ||
| 1068 | |||
| 1069 | if (FC_OP == OP_UNARY_NUM_HARDSWISH) { | ||
| 1070 | dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5))); | ||
| 1071 | } | ||
| 1072 | |||
| 1073 | if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) { | ||
| 1074 | dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5)); | ||
| 1075 | } | ||
| 1076 | |||
| 1077 | if (FC_OP == OP_UNARY_NUM_EXP) { | ||
| 1078 | dst_ptr[i0] = (T) exp(x); | ||
| 1079 | } | ||
| 1080 | |||
| 1081 | if (FC_OP == OP_UNARY_NUM_SOFTPLUS) { | ||
| 1082 | dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20); | ||
| 1083 | } | ||
| 1084 | |||
| 1085 | if (FC_OP == OP_UNARY_NUM_EXPM1) { | ||
| 1086 | // TODO: precise implementation | ||
| 1087 | dst_ptr[i0] = (T) (exp(x) - 1); | ||
| 1088 | } | ||
| 1089 | } | ||
| 1090 | |||
| 1091 | #undef FC_OP | ||
| 1092 | #undef FC_CNT | ||
| 1093 | } | ||
| 1094 | |||
| 1095 | typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t; | ||
| 1096 | |||
| 1097 | template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>; | ||
| 1098 | template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>; | ||
| 1099 | template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>; | ||
| 1100 | template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>; | ||
| 1101 | |||
| 1102 | // OP: 0 - add, 1 - sub, 2 - mul, 3 - div | ||
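| | // FC_bin_f is the number of fused src1 operands (offsets taken from args.o1); | ||
| | // FC_bin_rb selects the row-broadcast fast path (one src1 row repeated across a flat src0/dst) | ||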
| 1103 | constant short FC_bin_op [[function_constant(FC_BIN + 0)]]; | ||
| 1104 | constant short FC_bin_f [[function_constant(FC_BIN + 1)]]; | ||
| 1105 | constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]]; | ||
| 1106 | |||
| 1107 | template <typename T0, typename T1, typename T> | ||
| 1108 | kernel void kernel_bin_fuse_impl( | ||
| 1109 | constant ggml_metal_kargs_bin & args, | ||
| 1110 | device const char * src0, | ||
| 1111 | device const char * src1, | ||
| 1112 | device char * dst, | ||
| 1113 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 1114 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 1115 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 1116 | #define FC_OP FC_bin_op | ||
| 1117 | #define FC_F FC_bin_f | ||
| 1118 | #define FC_RB FC_bin_rb | ||
| 1119 | |||
| 1120 | if (FC_RB) { | ||
| 1121 | // row broadcast | ||
| 1122 | const uint i0 = tgpig.x; | ||
| 1123 | const uint i1 = i0%args.ne10; | ||
| 1124 | |||
| 1125 | device const T0 * src0_row = (device const T0 *) (src0); | ||
| 1126 | device T * dst_row = (device T *) (dst); | ||
| 1127 | |||
| 1128 | if (FC_F == 1) { | ||
| 1129 | device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]); | ||
| 1130 | |||
| 1131 | if (FC_OP == 0) { | ||
| 1132 | dst_row[i0] = src0_row[i0] + src1_row[i1]; | ||
| 1133 | } | ||
| 1134 | |||
| 1135 | if (FC_OP == 1) { | ||
| 1136 | dst_row[i0] = src0_row[i0] - src1_row[i1]; | ||
| 1137 | } | ||
| 1138 | |||
| 1139 | if (FC_OP == 2) { | ||
| 1140 | dst_row[i0] = src0_row[i0] * src1_row[i1]; | ||
| 1141 | } | ||
| 1142 | |||
| 1143 | if (FC_OP == 3) { | ||
| 1144 | dst_row[i0] = src0_row[i0] / src1_row[i1]; | ||
| 1145 | } | ||
| 1146 | } else { | ||
| 1147 | T0 res = src0_row[i0]; | ||
| 1148 | |||
| 1149 | if (FC_OP == 0) { | ||
| 1150 | FOR_UNROLL (short j = 0; j < FC_F; ++j) { | ||
| 1151 | res += ((device const T1 *) (src1 + args.o1[j]))[i1]; | ||
| 1152 | } | ||
| 1153 | } | ||
| 1154 | |||
| 1155 | if (FC_OP == 1) { | ||
| 1156 | FOR_UNROLL (short j = 0; j < FC_F; ++j) { | ||
| 1157 | res -= ((device const T1 *) (src1 + args.o1[j]))[i1]; | ||
| 1158 | } | ||
| 1159 | } | ||
| 1160 | |||
| 1161 | if (FC_OP == 2) { | ||
| 1162 | FOR_UNROLL (short j = 0; j < FC_F; ++j) { | ||
| 1163 | res *= ((device const T1 *) (src1 + args.o1[j]))[i1]; | ||
| 1164 | } | ||
| 1165 | } | ||
| 1166 | |||
| 1167 | if (FC_OP == 3) { | ||
| 1168 | FOR_UNROLL (short j = 0; j < FC_F; ++j) { | ||
| 1169 | res /= ((device const T1 *) (src1 + args.o1[j]))[i1]; | ||
| 1170 | } | ||
| 1171 | } | ||
| 1172 | |||
| 1173 | dst_row[i0] = res; | ||
| 1174 | } | ||
| 1175 | } else { | ||
| 1176 | const int i03 = tgpig.z; | ||
| 1177 | const int i02 = tgpig.y; | ||
| 1178 | const int i01 = tgpig.x; | ||
| 1179 | |||
| 1180 | if (i01 >= args.ne01) { | ||
| 1181 | return; | ||
| 1182 | } | ||
| 1183 | |||
| 1184 | const int i13 = i03%args.ne13; | ||
| 1185 | const int i12 = i02%args.ne12; | ||
| 1186 | const int i11 = i01%args.ne11; | ||
| 1187 | |||
| 1188 | device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); | ||
| 1189 | device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); | ||
| 1190 | |||
| 1191 | if (FC_F == 1) { | ||
| 1192 | device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); | ||
| 1193 | |||
| 1194 | for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { | ||
| 1195 | const int i10 = i0%args.ne10; | ||
| 1196 | |||
| 1197 | if (FC_OP == 0) { | ||
| 1198 | dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10]; | ||
| 1199 | } | ||
| 1200 | |||
| 1201 | if (FC_OP == 1) { | ||
| 1202 | dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10]; | ||
| 1203 | } | ||
| 1204 | |||
| 1205 | if (FC_OP == 2) { | ||
| 1206 | dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10]; | ||
| 1207 | } | ||
| 1208 | |||
| 1209 | if (FC_OP == 3) { | ||
| 1210 | dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10]; | ||
| 1211 | } | ||
| 1212 | } | ||
| 1213 | } else { | ||
| 1214 | device const T1 * src1_ptr[8]; | ||
| 1215 | FOR_UNROLL (short j = 0; j < FC_F; ++j) { | ||
| 1216 | src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); | ||
| 1217 | } | ||
| 1218 | |||
| 1219 | for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { | ||
| 1220 | const int i10 = i0%args.ne10; | ||
| 1221 | |||
| 1222 | T res = src0_ptr[i0]; | ||
| 1223 | |||
| 1224 | if (FC_OP == 0) { | ||
| 1225 | FOR_UNROLL (short j = 0; j < FC_F; ++j) { | ||
| 1226 | res += src1_ptr[j][i10]; | ||
| 1227 | } | ||
| 1228 | } | ||
| 1229 | |||
| 1230 | if (FC_OP == 1) { | ||
| 1231 | FOR_UNROLL (short j = 0; j < FC_F; ++j) { | ||
| 1232 | res -= src1_ptr[j][i10]; | ||
| 1233 | } | ||
| 1234 | } | ||
| 1235 | |||
| 1236 | if (FC_OP == 2) { | ||
| 1237 | FOR_UNROLL (short j = 0; j < FC_F; ++j) { | ||
| 1238 | res *= src1_ptr[j][i10]; | ||
| 1239 | } | ||
| 1240 | } | ||
| 1241 | |||
| 1242 | if (FC_OP == 3) { | ||
| 1243 | FOR_UNROLL (short j = 0; j < FC_F; ++j) { | ||
| 1244 | res /= src1_ptr[j][i10]; | ||
| 1245 | } | ||
| 1246 | } | ||
| 1247 | |||
| 1248 | dst_ptr[i0] = res; | ||
| 1249 | } | ||
| 1250 | } | ||
| 1251 | } | ||
| 1252 | |||
| 1253 | #undef FC_OP | ||
| 1254 | #undef FC_F | ||
| 1255 | #undef FC_RB | ||
| 1256 | } | ||
| 1257 | |||
| 1258 | typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t; | ||
| 1259 | |||
| 1260 | template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>; | ||
| 1261 | template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>; | ||
| 1262 | |||
| 1263 | kernel void kernel_add_id( | ||
| 1264 | constant ggml_metal_kargs_add_id & args, | ||
| 1265 | device const char * src0, | ||
| 1266 | device const char * src1, | ||
| 1267 | device const char * src2, | ||
| 1268 | device char * dst, | ||
| 1269 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 1270 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 1271 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 1272 | const int i1 = tgpig.x; | ||
| 1273 | const int i2 = tgpig.y; | ||
| 1274 | |||
| 1275 | const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21)); | ||
| 1276 | |||
| 1277 | const size_t nb1 = args.ne0 * sizeof(float); | ||
| 1278 | const size_t nb2 = args.ne1 * nb1; | ||
| 1279 | |||
| 1280 | device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2); | ||
| 1281 | device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02); | ||
| 1282 | device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11); | ||
| 1283 | |||
| 1284 | for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { | ||
| 1285 | dst_row[i0] = src0_row[i0] + src1_row[i0]; | ||
| 1286 | } | ||
| 1287 | } | ||
| 1288 | |||
| 1289 | template<typename T> | ||
| 1290 | kernel void kernel_repeat( | ||
| 1291 | constant ggml_metal_kargs_repeat & args, | ||
| 1292 | device const char * src0, | ||
| 1293 | device char * dst, | ||
| 1294 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 1295 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 1296 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 1297 | const int i3 = tgpig.z; | ||
| 1298 | const int i2 = tgpig.y; | ||
| 1299 | const int i1 = tgpig.x; | ||
| 1300 | |||
| 1301 | const int i03 = i3%args.ne03; | ||
| 1302 | const int i02 = i2%args.ne02; | ||
| 1303 | const int i01 = i1%args.ne01; | ||
| 1304 | |||
| 1305 | device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; | ||
| 1306 | device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1; | ||
| 1307 | |||
| 1308 | for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { | ||
| 1309 | const int i00 = i0%args.ne00; | ||
| 1310 | *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00)); | ||
| 1311 | } | ||
| 1312 | } | ||
| 1313 | |||
| 1314 | typedef decltype(kernel_repeat<float>) kernel_repeat_t; | ||
| 1315 | |||
| 1316 | template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>; | ||
| 1317 | template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>; | ||
| 1318 | template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>; | ||
| 1319 | template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>; | ||
| 1320 | |||
| 1321 | kernel void kernel_reglu_f32( | ||
| 1322 | constant ggml_metal_kargs_glu & args, | ||
| 1323 | device const char * src0, | ||
| 1324 | device const char * src1, | ||
| 1325 | device char * dst, | ||
| 1326 | uint tgpig[[threadgroup_position_in_grid]], | ||
| 1327 | uint tpitg[[thread_position_in_threadgroup]], | ||
| 1328 | uint ntg[[threads_per_threadgroup]]) { | ||
| 1329 | device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; | ||
| 1330 | device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; | ||
| 1331 | device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); | ||
| 1332 | |||
| 1333 | for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { | ||
| 1334 | const float x0 = src0_row[i0]; | ||
| 1335 | const float x1 = src1_row[i0]; | ||
| 1336 | |||
| 1337 | dst_row[i0] = x0*x1*(x0 > 0.0f); | ||
| 1338 | } | ||
| 1339 | } | ||
| 1340 | |||
| 1341 | kernel void kernel_geglu_f32( | ||
| 1342 | constant ggml_metal_kargs_glu & args, | ||
| 1343 | device const char * src0, | ||
| 1344 | device const char * src1, | ||
| 1345 | device char * dst, | ||
| 1346 | uint tgpig[[threadgroup_position_in_grid]], | ||
| 1347 | uint tpitg[[thread_position_in_threadgroup]], | ||
| 1348 | uint ntg[[threads_per_threadgroup]]) { | ||
| 1349 | device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; | ||
| 1350 | device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; | ||
| 1351 | device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); | ||
| 1352 | |||
| 1353 | for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { | ||
| 1354 | const float x0 = src0_row[i0]; | ||
| 1355 | const float x1 = src1_row[i0]; | ||
| 1356 | |||
| 1357 | const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0))); | ||
| 1358 | |||
| 1359 | dst_row[i0] = gelu*x1; | ||
| 1360 | } | ||
| 1361 | } | ||
| 1362 | |||
| 1363 | kernel void kernel_swiglu_f32( | ||
| 1364 | constant ggml_metal_kargs_glu & args, | ||
| 1365 | device const char * src0, | ||
| 1366 | device const char * src1, | ||
| 1367 | device char * dst, | ||
| 1368 | uint tgpig[[threadgroup_position_in_grid]], | ||
| 1369 | uint tpitg[[thread_position_in_threadgroup]], | ||
| 1370 | uint ntg[[threads_per_threadgroup]]) { | ||
| 1371 | device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; | ||
| 1372 | device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; | ||
| 1373 | device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); | ||
| 1374 | |||
| 1375 | for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { | ||
| 1376 | const float x0 = src0_row[i0]; | ||
| 1377 | const float x1 = src1_row[i0]; | ||
| 1378 | |||
| 1379 | const float silu = x0 / (1.0f + exp(-x0)); | ||
| 1380 | |||
| 1381 | dst_row[i0] = silu*x1; | ||
| 1382 | } | ||
| 1383 | } | ||
| 1384 | |||
| 1385 | kernel void kernel_swiglu_oai_f32( | ||
| 1386 | constant ggml_metal_kargs_glu & args, | ||
| 1387 | device const char * src0, | ||
| 1388 | device const char * src1, | ||
| 1389 | device char * dst, | ||
| 1390 | uint tgpig[[threadgroup_position_in_grid]], | ||
| 1391 | uint tpitg[[thread_position_in_threadgroup]], | ||
| 1392 | uint ntg[[threads_per_threadgroup]]) { | ||
| 1393 | device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; | ||
| 1394 | device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; | ||
| 1395 | device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); | ||
| 1396 | |||
| 1397 | for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { | ||
| 1398 | float x0 = src0_row[i0]; | ||
| 1399 | float x1 = src1_row[i0]; | ||
| 1400 | |||
| 1401 | x0 = min(x0, args.limit); | ||
| 1402 | x1 = max(min(x1, args.limit), -args.limit); | ||
| 1403 | |||
| 1404 | float out_glu = x0 / (1.0f + exp(-x0 * args.alpha)); | ||
| 1405 | out_glu = out_glu * (1.0f + x1); | ||
| 1406 | |||
| 1407 | dst_row[i0] = out_glu; | ||
| 1408 | } | ||
| 1409 | } | ||
| 1410 | |||
| 1411 | kernel void kernel_geglu_erf_f32( | ||
| 1412 | constant ggml_metal_kargs_glu & args, | ||
| 1413 | device const char * src0, | ||
| 1414 | device const char * src1, | ||
| 1415 | device char * dst, | ||
| 1416 | uint tgpig[[threadgroup_position_in_grid]], | ||
| 1417 | uint tpitg[[thread_position_in_threadgroup]], | ||
| 1418 | uint ntg[[threads_per_threadgroup]]) { | ||
| 1419 | device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; | ||
| 1420 | device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; | ||
| 1421 | device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); | ||
| 1422 | |||
| 1423 | for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { | ||
| 1424 | const float x0 = src0_row[i0]; | ||
| 1425 | const float x1 = src1_row[i0]; | ||
| 1426 | |||
| 1427 | const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV)); | ||
| 1428 | |||
| 1429 | dst_row[i0] = gelu_erf*x1; | ||
| 1430 | } | ||
| 1431 | } | ||
| 1432 | |||
| 1433 | kernel void kernel_geglu_quick_f32( | ||
| 1434 | constant ggml_metal_kargs_glu & args, | ||
| 1435 | device const char * src0, | ||
| 1436 | device const char * src1, | ||
| 1437 | device char * dst, | ||
| 1438 | uint tgpig[[threadgroup_position_in_grid]], | ||
| 1439 | uint tpitg[[thread_position_in_threadgroup]], | ||
| 1440 | uint ntg[[threads_per_threadgroup]]) { | ||
| 1441 | device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; | ||
| 1442 | device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; | ||
| 1443 | device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); | ||
| 1444 | |||
| 1445 | for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { | ||
| 1446 | const float x0 = src0_row[i0]; | ||
| 1447 | const float x1 = src1_row[i0]; | ||
| 1448 | |||
| 1449 | const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0))); | ||
| 1450 | |||
| 1451 | dst_row[i0] = gelu_quick*x1; | ||
| 1452 | } | ||
| 1453 | } | ||
| 1454 | |||
| 1455 | kernel void kernel_op_sum_f32( | ||
| 1456 | constant ggml_metal_kargs_sum & args, | ||
| 1457 | device const float * src0, | ||
| 1458 | device float * dst, | ||
| 1459 | threadgroup float * shmem_f32 [[threadgroup(0)]], | ||
| 1460 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 1461 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 1462 | ushort sgitg[[simdgroup_index_in_threadgroup]], | ||
| 1463 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 1464 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 1465 | |||
| 1466 | if (args.np == 0) { | ||
| 1467 | return; | ||
| 1468 | } | ||
| 1469 | |||
| 1470 | // TODO: make this a function constant | ||
| 1471 | const uint nsg = (ntg.x + 31) / 32; | ||
| 1472 | |||
| 1473 | float sumf = 0; | ||
| 1474 | |||
| 1475 | for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) { | ||
| 1476 | sumf += src0[i0]; | ||
| 1477 | } | ||
| 1478 | |||
| 1479 | sumf = simd_sum(sumf); | ||
| 1480 | |||
| 1481 | if (tiisg == 0) { | ||
| 1482 | shmem_f32[sgitg] = sumf; | ||
| 1483 | } | ||
| 1484 | |||
| 1485 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 1486 | |||
| 1487 | float total = 0; | ||
| 1488 | |||
| 1489 | if (sgitg == 0) { | ||
| 1490 | float v = 0; | ||
| 1491 | |||
| 1492 | if (tpitg.x < nsg) { | ||
| 1493 | v = shmem_f32[tpitg.x]; | ||
| 1494 | } | ||
| 1495 | |||
| 1496 | total = simd_sum(v); | ||
| 1497 | |||
| 1498 | if (tpitg.x == 0) { | ||
| 1499 | dst[0] = total; | ||
| 1500 | } | ||
| 1501 | } | ||
| 1502 | } | ||
| 1503 | |||
| 1504 | template <bool norm> | ||
| 1505 | kernel void kernel_sum_rows( | ||
| 1506 | constant ggml_metal_kargs_sum_rows & args, | ||
| 1507 | device const float * src0, | ||
| 1508 | device float * dst, | ||
| 1509 | threadgroup float * shmem_f32 [[threadgroup(0)]], | ||
| 1510 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 1511 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 1512 | ushort sgitg[[simdgroup_index_in_threadgroup]], | ||
| 1513 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 1514 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 1515 | int64_t i3 = tgpig.z; | ||
| 1516 | int64_t i2 = tgpig.y; | ||
| 1517 | int64_t i1 = tgpig.x; | ||
| 1518 | |||
| 1519 | if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { | ||
| 1520 | return; | ||
| 1521 | } | ||
| 1522 | |||
| 1523 | if (sgitg == 0) { | ||
| 1524 | shmem_f32[tiisg] = 0.0f; | ||
| 1525 | } | ||
| 1526 | |||
| 1527 | device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); | ||
| 1528 | device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); | ||
| 1529 | |||
| 1530 | float sumf = 0; | ||
| 1531 | |||
| 1532 | for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { | ||
| 1533 | sumf += src_row[i0]; | ||
| 1534 | } | ||
| 1535 | |||
| 1536 | sumf = simd_sum(sumf); | ||
| 1537 | |||
| 1538 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 1539 | |||
| 1540 | if (tiisg == 0) { | ||
| 1541 | shmem_f32[sgitg] = sumf; | ||
| 1542 | } | ||
| 1543 | |||
| 1544 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 1545 | |||
| 1546 | sumf = shmem_f32[tiisg]; | ||
| 1547 | sumf = simd_sum(sumf); | ||
| 1548 | |||
| 1549 | if (tpitg.x == 0) { | ||
| 1550 | dst_row[0] = norm ? sumf / args.ne00 : sumf; | ||
| 1551 | } | ||
| 1552 | } | ||
| 1553 | |||
| 1554 | typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t; | ||
| 1555 | |||
| 1556 | template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>; | ||
| 1557 | template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>; | ||
| 1558 | |||
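| | // two-pass cumulative sum: kernel_cumsum_blk computes an inclusive prefix sum within each block | ||
| | // of ntg.x elements and optionally (args.outb) writes each block's total to tmp; kernel_cumsum_add | ||
| | // then adds tmp[ib - 1] to every element of block ib > 0 to propagate the carry across blocks | ||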
| 1559 | template<typename T> | ||
| 1560 | kernel void kernel_cumsum_blk( | ||
| 1561 | constant ggml_metal_kargs_cumsum_blk & args, | ||
| 1562 | device const char * src0, | ||
| 1563 | device char * tmp, | ||
| 1564 | device char * dst, | ||
| 1565 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 1566 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 1567 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 1568 | ushort sgitg[[simdgroup_index_in_threadgroup]], | ||
| 1569 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 1570 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 1571 | const int ib = tgpig[0]/args.ne01; | ||
| 1572 | |||
| 1573 | const int i00 = ib*ntg.x; | ||
| 1574 | const int i01 = tgpig[0]%args.ne01; | ||
| 1575 | const int i02 = tgpig[1]; | ||
| 1576 | const int i03 = tgpig[2]; | ||
| 1577 | |||
| 1578 | device const float * src0_row = (device const float *) (src0 + | ||
| 1579 | args.nb01*i01 + | ||
| 1580 | args.nb02*i02 + | ||
| 1581 | args.nb03*i03); | ||
| 1582 | |||
| 1583 | threadgroup float * shmem_f32 = (threadgroup float *) shmem; | ||
| 1584 | |||
| 1585 | float v = 0.0f; | ||
| 1586 | |||
| 1587 | if (i00 + tpitg.x < args.ne00) { | ||
| 1588 | v = src0_row[i00 + tpitg.x]; | ||
| 1589 | } | ||
| 1590 | |||
| 1591 | float s = simd_prefix_inclusive_sum(v); | ||
| 1592 | |||
| 1593 | if (tiisg == N_SIMDWIDTH - 1) { | ||
| 1594 | shmem_f32[sgitg] = s; | ||
| 1595 | } | ||
| 1596 | |||
| 1597 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 1598 | |||
| 1599 | if (sgitg == 0) { | ||
| 1600 | shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]); | ||
| 1601 | } | ||
| 1602 | |||
| 1603 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 1604 | |||
| 1605 | s += shmem_f32[sgitg]; | ||
| 1606 | |||
| 1607 | device float * dst_row = (device float *) dst + | ||
| 1608 | args.ne00*i01 + | ||
| 1609 | args.ne00*args.ne01*i02 + | ||
| 1610 | args.ne00*args.ne01*args.ne02*i03; | ||
| 1611 | |||
| 1612 | if (i00 + tpitg.x < args.ne00) { | ||
| 1613 | dst_row[i00 + tpitg.x] = s; | ||
| 1614 | } | ||
| 1615 | |||
| 1616 | if (args.outb && tpitg.x == ntg.x - 1) { | ||
| 1617 | device float * tmp_row = (device float *) tmp + | ||
| 1618 | args.net0*i01 + | ||
| 1619 | args.net0*args.net1*i02 + | ||
| 1620 | args.net0*args.net1*args.net2*i03; | ||
| 1621 | |||
| 1622 | tmp_row[ib] = s; | ||
| 1623 | } | ||
| 1624 | } | ||
| 1625 | |||
| 1626 | typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t; | ||
| 1627 | |||
| 1628 | template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>; | ||
| 1629 | |||
| 1630 | template<typename T> | ||
| 1631 | kernel void kernel_cumsum_add( | ||
| 1632 | constant ggml_metal_kargs_cumsum_add & args, | ||
| 1633 | device const char * tmp, | ||
| 1634 | device char * dst, | ||
| 1635 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 1636 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 1637 | ushort sgitg[[simdgroup_index_in_threadgroup]], | ||
| 1638 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 1639 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 1640 | const int ib = tgpig[0]/args.ne01; | ||
| 1641 | |||
| 1642 | if (ib == 0) { | ||
| 1643 | return; | ||
| 1644 | } | ||
| 1645 | |||
| 1646 | const int i00 = ib*ntg.x; | ||
| 1647 | const int i01 = tgpig[0]%args.ne01; | ||
| 1648 | const int i02 = tgpig[1]; | ||
| 1649 | const int i03 = tgpig[2]; | ||
| 1650 | |||
| 1651 | device const float * tmp_row = (device const float *) (tmp + | ||
| 1652 | args.nbt1*i01 + | ||
| 1653 | args.nbt2*i02 + | ||
| 1654 | args.nbt3*i03); | ||
| 1655 | |||
| 1656 | device float * dst_row = (device float *) dst + | ||
| 1657 | args.ne00*i01 + | ||
| 1658 | args.ne00*args.ne01*i02 + | ||
| 1659 | args.ne00*args.ne01*args.ne02*i03; | ||
| 1660 | |||
| 1661 | if (i00 + tpitg.x < args.ne00) { | ||
| 1662 | dst_row[i00 + tpitg.x] += tmp_row[ib - 1]; | ||
| 1663 | } | ||
| 1664 | } | ||
| 1665 | |||
| 1666 | typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t; | ||
| 1667 | |||
| 1668 | template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>; | ||
| 1669 | |||
| 1670 | |||
| 1671 | template<uint32_t ttype> | ||
| 1672 | bool _ggml_vec_tri_cmp(const int i, const int r); | ||
| 1673 | |||
| 1674 | template<> | ||
| 1675 | bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER */ 3>(const int i, const int r) { | ||
| 1676 | return i < r; | ||
| 1677 | } | ||
| 1678 | |||
| 1679 | template<> | ||
| 1680 | bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_LOWER_DIAG */ 2>(const int i, const int r) { | ||
| 1681 | return i <= r; | ||
| 1682 | } | ||
| 1683 | |||
| 1684 | template<> | ||
| 1685 | bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER */ 1>(const int i, const int r) { | ||
| 1686 | return i > r; | ||
| 1687 | } | ||
| 1688 | |||
| 1689 | template<> | ||
| 1690 | bool _ggml_vec_tri_cmp</* GGML_TRI_TYPE_UPPER_DIAG */ 0>(const int i, const int r) { | ||
| 1691 | return i >= r; | ||
| 1692 | } | ||
| 1693 | |||
| 1694 | template<typename T, int ttype> | ||
| 1695 | kernel void kernel_tri( | ||
| 1696 | constant ggml_metal_kargs_tri & args, | ||
| 1697 | device const char * src0, | ||
| 1698 | device const char * dst, | ||
| 1699 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 1700 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 1701 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 1702 | const int i3 = tgpig.z; | ||
| 1703 | const int i2 = tgpig.y; | ||
| 1704 | const int i1 = tgpig.x; | ||
| 1705 | |||
| 1706 | if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { | ||
| 1707 | return; | ||
| 1708 | } | ||
| 1709 | |||
| 1710 | device const T * src_row = (device const T *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); | ||
| 1711 | device T * dst_row = (device T *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); | ||
| 1712 | |||
| 1713 | // Each thread handles a single element of the row when ne00 fits within the | ||
| 1714 | // max threads per threadgroup; otherwise the loop below runs once for each | ||
| 1715 | // index this thread is responsible for | ||
| 1716 | for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { | ||
| 1717 | // Use the comparison result as a mask to keep the selection branchless | ||
| 1718 | dst_row[i0] = static_cast<T>(_ggml_vec_tri_cmp<ttype>(i0, i1)) * src_row[i0]; | ||
| 1719 | } | ||
| 1720 | } | ||
| 1721 | |||
| 1722 | typedef decltype(kernel_tri<float, 0>) kernel_tri_t; | ||
| 1723 | |||
| 1724 | template [[host_name("kernel_tri_f32_0")]] kernel kernel_tri_t kernel_tri<float, 0>; | ||
| 1725 | template [[host_name("kernel_tri_f32_1")]] kernel kernel_tri_t kernel_tri<float, 1>; | ||
| 1726 | template [[host_name("kernel_tri_f32_2")]] kernel kernel_tri_t kernel_tri<float, 2>; | ||
| 1727 | template [[host_name("kernel_tri_f32_3")]] kernel kernel_tri_t kernel_tri<float, 3>; | ||
| 1728 | template [[host_name("kernel_tri_f16_0")]] kernel kernel_tri_t kernel_tri<half, 0>; | ||
| 1729 | template [[host_name("kernel_tri_f16_1")]] kernel kernel_tri_t kernel_tri<half, 1>; | ||
| 1730 | template [[host_name("kernel_tri_f16_2")]] kernel kernel_tri_t kernel_tri<half, 2>; | ||
| 1731 | template [[host_name("kernel_tri_f16_3")]] kernel kernel_tri_t kernel_tri<half, 3>; | ||
| 1732 | #if defined(GGML_METAL_HAS_BF16) | ||
| 1733 | template [[host_name("kernel_tri_bf16_0")]] kernel kernel_tri_t kernel_tri<bfloat, 0>; | ||
| 1734 | template [[host_name("kernel_tri_bf16_1")]] kernel kernel_tri_t kernel_tri<bfloat, 1>; | ||
| 1735 | template [[host_name("kernel_tri_bf16_2")]] kernel kernel_tri_t kernel_tri<bfloat, 2>; | ||
| 1736 | template [[host_name("kernel_tri_bf16_3")]] kernel kernel_tri_t kernel_tri<bfloat, 3>; | ||
| 1737 | #endif | ||
| 1738 | |||
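| | // row-wise softmax of args.scale*src0 plus an optional additive mask src1 (scaled by the ALiBi slope); | ||
| | // optional per-head sink values in src2 join the running max and add exp(sink - max) to the denominator | ||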
| 1739 | template<typename T> | ||
| 1740 | kernel void kernel_soft_max( | ||
| 1741 | constant ggml_metal_kargs_soft_max & args, | ||
| 1742 | device const char * src0, | ||
| 1743 | device const char * src1, | ||
| 1744 | device const char * src2, | ||
| 1745 | device char * dst, | ||
| 1746 | threadgroup float * buf [[threadgroup(0)]], | ||
| 1747 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 1748 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 1749 | uint sgitg[[simdgroup_index_in_threadgroup]], | ||
| 1750 | uint tiisg[[thread_index_in_simdgroup]], | ||
| 1751 | uint3 tptg[[threads_per_threadgroup]]) { | ||
| 1752 | const int32_t i03 = tgpig.z; | ||
| 1753 | const int32_t i02 = tgpig.y; | ||
| 1754 | const int32_t i01 = tgpig.x; | ||
| 1755 | |||
| 1756 | const int32_t i13 = i03%args.ne13; | ||
| 1757 | const int32_t i12 = i02%args.ne12; | ||
| 1758 | const int32_t i11 = i01; | ||
| 1759 | |||
| 1760 | device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); | ||
| 1761 | device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr; | ||
| 1762 | device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr; | ||
| 1763 | device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3); | ||
| 1764 | |||
| 1765 | float slope = 1.0f; | ||
| 1766 | |||
| 1767 | // ALiBi | ||
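| | // slope for head h: m0^(h + 1) if h < n_head_log2, otherwise m1^(2*(h - n_head_log2) + 1) | ||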
| 1768 | if (args.max_bias > 0.0f) { | ||
| 1769 | const int32_t h = i02; | ||
| 1770 | |||
| 1771 | const float base = h < args.n_head_log2 ? args.m0 : args.m1; | ||
| 1772 | const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; | ||
| 1773 | |||
| 1774 | slope = pow(base, exp); | ||
| 1775 | } | ||
| 1776 | |||
| 1777 | // parallel max | ||
| 1778 | float lmax = psrc2 ? psrc2[i02] : -INFINITY; | ||
| 1779 | |||
| 1780 | for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { | ||
| 1781 | lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)); | ||
| 1782 | } | ||
| 1783 | |||
| 1784 | // find the max value in the block | ||
| 1785 | float max_val = simd_max(lmax); | ||
| 1786 | if (tptg.x > N_SIMDWIDTH) { | ||
| 1787 | if (sgitg == 0) { | ||
| 1788 | buf[tiisg] = -INFINITY; | ||
| 1789 | } | ||
| 1790 | |||
| 1791 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 1792 | |||
| 1793 | if (tiisg == 0) { | ||
| 1794 | buf[sgitg] = max_val; | ||
| 1795 | } | ||
| 1796 | |||
| 1797 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 1798 | |||
| 1799 | max_val = buf[tiisg]; | ||
| 1800 | max_val = simd_max(max_val); | ||
| 1801 | } | ||
| 1802 | |||
| 1803 | // parallel sum | ||
| 1804 | float lsum = 0.0f; | ||
| 1805 | for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { | ||
| 1806 | const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val); | ||
| 1807 | lsum += exp_psrc0; | ||
| 1808 | pdst[i00] = exp_psrc0; | ||
| 1809 | } | ||
| 1810 | |||
| 1811 | // This barrier fixes a failing test | ||
| 1812 | // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335 | ||
| 1813 | threadgroup_barrier(mem_flags::mem_none); | ||
| 1814 | |||
| 1815 | float sum = simd_sum(lsum); | ||
| 1816 | |||
| 1817 | if (tptg.x > N_SIMDWIDTH) { | ||
| 1818 | if (sgitg == 0) { | ||
| 1819 | buf[tiisg] = 0.0f; | ||
| 1820 | } | ||
| 1821 | |||
| 1822 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 1823 | |||
| 1824 | if (tiisg == 0) { | ||
| 1825 | buf[sgitg] = sum; | ||
| 1826 | } | ||
| 1827 | |||
| 1828 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 1829 | |||
| 1830 | sum = buf[tiisg]; | ||
| 1831 | sum = simd_sum(sum); | ||
| 1832 | } | ||
| 1833 | |||
| 1834 | if (psrc2) { | ||
| 1835 | sum += exp(psrc2[i02] - max_val); | ||
| 1836 | } | ||
| 1837 | |||
| 1838 | const float inv_sum = 1.0f/sum; | ||
| 1839 | |||
| 1840 | for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) { | ||
| 1841 | pdst[i00] *= inv_sum; | ||
| 1842 | } | ||
| 1843 | } | ||
| 1844 | |||
| 1845 | template<typename T> | ||
| 1846 | kernel void kernel_soft_max_4( | ||
| 1847 | constant ggml_metal_kargs_soft_max & args, | ||
| 1848 | device const char * src0, | ||
| 1849 | device const char * src1, | ||
| 1850 | device const char * src2, | ||
| 1851 | device char * dst, | ||
| 1852 | threadgroup float * buf [[threadgroup(0)]], | ||
| 1853 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 1854 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 1855 | uint sgitg[[simdgroup_index_in_threadgroup]], | ||
| 1856 | uint tiisg[[thread_index_in_simdgroup]], | ||
| 1857 | uint3 tptg[[threads_per_threadgroup]]) { | ||
| 1858 | const int32_t i03 = tgpig.z; | ||
| 1859 | const int32_t i02 = tgpig.y; | ||
| 1860 | const int32_t i01 = tgpig.x; | ||
| 1861 | |||
| 1862 | const int32_t i13 = i03%args.ne13; | ||
| 1863 | const int32_t i12 = i02%args.ne12; | ||
| 1864 | const int32_t i11 = i01; | ||
| 1865 | |||
| 1866 | device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); | ||
| 1867 | device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr; | ||
| 1868 | device const float * psrc2 = src2 != src0 ? (device const float * ) (src2) : nullptr; | ||
| 1869 | device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3); | ||
| 1870 | |||
| 1871 | float slope = 1.0f; | ||
| 1872 | |||
| 1873 | if (args.max_bias > 0.0f) { | ||
| 1874 | const int32_t h = i02; | ||
| 1875 | |||
| 1876 | const float base = h < args.n_head_log2 ? args.m0 : args.m1; | ||
| 1877 | const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; | ||
| 1878 | |||
| 1879 | slope = pow(base, exp); | ||
| 1880 | } | ||
| 1881 | |||
| 1882 | // parallel max | ||
| 1883 | float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY; | ||
| 1884 | |||
| 1885 | for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { | ||
| 1886 | lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))); | ||
| 1887 | } | ||
| 1888 | |||
| 1889 | const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); | ||
| 1890 | |||
| 1891 | float max_val = simd_max(lmax); | ||
| 1892 | if (tptg.x > N_SIMDWIDTH) { | ||
| 1893 | if (sgitg == 0) { | ||
| 1894 | buf[tiisg] = -INFINITY; | ||
| 1895 | } | ||
| 1896 | |||
| 1897 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 1898 | |||
| 1899 | if (tiisg == 0) { | ||
| 1900 | buf[sgitg] = max_val; | ||
| 1901 | } | ||
| 1902 | |||
| 1903 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 1904 | |||
| 1905 | max_val = buf[tiisg]; | ||
| 1906 | max_val = simd_max(max_val); | ||
| 1907 | } | ||
| 1908 | |||
| 1909 | // parallel sum | ||
| 1910 | float4 lsum4 = 0.0f; | ||
| 1911 | for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { | ||
| 1912 | const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val); | ||
| 1913 | lsum4 += exp_psrc4; | ||
| 1914 | pdst4[i00] = exp_psrc4; | ||
| 1915 | } | ||
| 1916 | |||
| 1917 | const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; | ||
| 1918 | |||
| 1919 | // This barrier fixes a failing test | ||
| 1920 | // ref: https://github.com/ggml-org/ggml/pull/621#discussion_r1425156335 | ||
| 1921 | threadgroup_barrier(mem_flags::mem_none); | ||
| 1922 | |||
| 1923 | float sum = simd_sum(lsum); | ||
| 1924 | |||
| 1925 | if (tptg.x > N_SIMDWIDTH) { | ||
| 1926 | if (sgitg == 0) { | ||
| 1927 | buf[tiisg] = 0.0f; | ||
| 1928 | } | ||
| 1929 | |||
| 1930 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 1931 | |||
| 1932 | if (tiisg == 0) { | ||
| 1933 | buf[sgitg] = sum; | ||
| 1934 | } | ||
| 1935 | |||
| 1936 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 1937 | |||
| 1938 | sum = buf[tiisg]; | ||
| 1939 | sum = simd_sum(sum); | ||
| 1940 | } | ||
| 1941 | |||
| 1942 | if (psrc2) { | ||
| 1943 | sum += exp(psrc2[i02] - max_val); | ||
| 1944 | } | ||
| 1945 | |||
| 1946 | const float inv_sum = 1.0f/sum; | ||
| 1947 | |||
| 1948 | for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) { | ||
| 1949 | pdst4[i00] *= inv_sum; | ||
| 1950 | } | ||
| 1951 | } | ||
| 1952 | |||
| 1953 | typedef decltype(kernel_soft_max<float>) kernel_soft_max_t; | ||
| 1954 | typedef decltype(kernel_soft_max_4<float4>) kernel_soft_max_4_t; | ||
| 1955 | |||
| 1956 | template [[host_name("kernel_soft_max_f16")]] kernel kernel_soft_max_t kernel_soft_max<half>; | ||
| 1957 | template [[host_name("kernel_soft_max_f32")]] kernel kernel_soft_max_t kernel_soft_max<float>; | ||
| 1958 | template [[host_name("kernel_soft_max_f16_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<half4>; | ||
| 1959 | template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kernel_soft_max_4<float4>; | ||
| 1960 | |||
| 1961 | // ref: ggml.c:ggml_compute_forward_ssm_conv_f32 | ||
| 1962 | kernel void kernel_ssm_conv_f32_f32( | ||
| 1963 | constant ggml_metal_kargs_ssm_conv & args, | ||
| 1964 | device const void * src0, | ||
| 1965 | device const void * src1, | ||
| 1966 | device float * dst, | ||
| 1967 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 1968 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 1969 | uint3 ntg[[threads_per_threadgroup]]) { | ||
| 1970 | const int64_t ir = tgpig.x; | ||
| 1971 | const int64_t i2 = tgpig.y; | ||
| 1972 | const int64_t i3 = tgpig.z; | ||
| 1973 | |||
| 1974 | const int64_t nc = args.ne10; | ||
| 1975 | //const int64_t ncs = args.ne00; | ||
| 1976 | //const int64_t nr = args.ne01; | ||
| 1977 | //const int64_t n_t = args.ne1; | ||
| 1978 | //const int64_t n_s = args.ne2; | ||
| 1979 | |||
| 1980 | device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); | ||
| 1981 | device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11); | ||
| 1982 | device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); | ||
| 1983 | |||
| 1984 | float sumf = 0.0f; | ||
| 1985 | |||
| 1986 | for (int64_t i0 = 0; i0 < nc; ++i0) { | ||
| 1987 | sumf += s[i0] * c[i0]; | ||
| 1988 | } | ||
| 1989 | |||
| 1990 | x[0] = sumf; | ||
| 1991 | } | ||
| 1992 | |||
| 1993 | kernel void kernel_ssm_conv_f32_f32_4( | ||
| 1994 | constant ggml_metal_kargs_ssm_conv & args, | ||
| 1995 | device const void * src0, | ||
| 1996 | device const void * src1, | ||
| 1997 | device float * dst, | ||
| 1998 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 1999 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 2000 | uint3 ntg[[threads_per_threadgroup]]) { | ||
| 2001 | const int64_t ir = tgpig.x; | ||
| 2002 | const int64_t i2 = tgpig.y; | ||
| 2003 | const int64_t i3 = tgpig.z; | ||
| 2004 | |||
| 2005 | const int64_t nc = args.ne10; | ||
| 2006 | //const int64_t ncs = args.ne00; | ||
| 2007 | //const int64_t nr = args.ne01; | ||
| 2008 | //const int64_t n_t = args.ne1; | ||
| 2009 | //const int64_t n_s = args.ne2; | ||
| 2010 | |||
| 2011 | device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); | ||
| 2012 | device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11); | ||
| 2013 | device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); | ||
| 2014 | |||
| 2015 | float sumf = 0.0f; | ||
| 2016 | |||
| 2017 | for (int64_t i0 = 0; i0 < nc/4; ++i0) { | ||
| 2018 | sumf += dot(s[i0], c[i0]); | ||
| 2019 | } | ||
| 2020 | |||
| 2021 | x[0] = sumf; | ||
| 2022 | } | ||
| 2023 | |||
| 2024 | constant short FC_ssm_conv_bs [[function_constant(FC_SSM_CONV + 0)]]; | ||
| 2025 | |||
| 2026 | // Batched version: each threadgroup processes multiple tokens for better efficiency | ||
| 2027 | // Thread layout: each thread handles one token, threadgroup covers BATCH_SIZE tokens | ||
| 2028 | kernel void kernel_ssm_conv_f32_f32_batched( | ||
| 2029 | constant ggml_metal_kargs_ssm_conv & args, | ||
| 2030 | device const void * src0, | ||
| 2031 | device const void * src1, | ||
| 2032 | device float * dst, | ||
| 2033 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 2034 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 2035 | uint3 ntg[[threads_per_threadgroup]]) { | ||
| 2036 | // tgpig.x = row index (ir) | ||
| 2037 | // tgpig.y = batch of tokens (i2_base / BATCH_SIZE) | ||
| 2038 | // tgpig.z = sequence index (i3) | ||
| 2039 | // tpitg.x = thread within batch (0..BATCH_SIZE-1) | ||
| 2040 | const short BATCH_SIZE = FC_ssm_conv_bs; | ||
| 2041 | |||
| 2042 | const int64_t ir = tgpig.x; | ||
| 2043 | const int64_t i2_base = tgpig.y * BATCH_SIZE; | ||
| 2044 | const int64_t i3 = tgpig.z; | ||
| 2045 | const int64_t i2_off = tpitg.x; | ||
| 2046 | const int64_t i2 = i2_base + i2_off; | ||
| 2047 | |||
| 2048 | const int64_t nc = args.ne10; // conv kernel size (typically 4) | ||
| 2049 | const int64_t n_t = args.ne1; // number of tokens | ||
| 2050 | |||
| 2051 | // Bounds check for partial batches at the end | ||
| 2052 | if (i2 >= n_t) { | ||
| 2053 | return; | ||
| 2054 | } | ||
| 2055 | |||
| 2056 | // Load conv weights (shared across all tokens for this row) | ||
| 2057 | device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11); | ||
| 2058 | |||
| 2059 | // Load source for this specific token | ||
| 2060 | device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); | ||
| 2061 | |||
| 2062 | // Output location for this token | ||
| 2063 | device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); | ||
| 2064 | |||
| 2065 | float sumf = 0.0f; | ||
| 2066 | for (int64_t i0 = 0; i0 < nc; ++i0) { | ||
| 2067 | sumf += s[i0] * c[i0]; | ||
| 2068 | } | ||
| 2069 | |||
| 2070 | x[0] = sumf; | ||
| 2071 | } | ||
| 2072 | |||
| 2073 | kernel void kernel_ssm_conv_f32_f32_batched_4( | ||
| 2074 | constant ggml_metal_kargs_ssm_conv & args, | ||
| 2075 | device const void * src0, | ||
| 2076 | device const void * src1, | ||
| 2077 | device float * dst, | ||
| 2078 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 2079 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 2080 | uint3 ntg[[threads_per_threadgroup]]) { | ||
| 2081 | // tgpig.x = row index (ir) | ||
| 2082 | // tgpig.y = batch of tokens (i2_base / BATCH_SIZE) | ||
| 2083 | // tgpig.z = sequence index (i3) | ||
| 2084 | // tpitg.x = thread within batch (0..BATCH_SIZE-1) | ||
| 2085 | const short BATCH_SIZE = FC_ssm_conv_bs; | ||
| 2086 | |||
| 2087 | const int64_t ir = tgpig.x; | ||
| 2088 | const int64_t i2_base = tgpig.y * BATCH_SIZE; | ||
| 2089 | const int64_t i3 = tgpig.z; | ||
| 2090 | const int64_t i2_off = tpitg.x; | ||
| 2091 | const int64_t i2 = i2_base + i2_off; | ||
| 2092 | |||
| 2093 | const int64_t nc = args.ne10; // conv kernel size (typically 4) | ||
| 2094 | const int64_t n_t = args.ne1; // number of tokens | ||
| 2095 | |||
| 2096 | // Bounds check for partial batches at the end | ||
| 2097 | if (i2 >= n_t) { | ||
| 2098 | return; | ||
| 2099 | } | ||
| 2100 | |||
| 2101 | // Load conv weights (shared across all tokens for this row) | ||
| 2102 | device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11); | ||
| 2103 | |||
| 2104 | // Load source for this specific token | ||
| 2105 | device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); | ||
| 2106 | |||
| 2107 | // Output location for this token | ||
| 2108 | device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); | ||
| 2109 | |||
| 2110 | float sumf = 0.0f; | ||
| 2111 | for (int64_t i0 = 0; i0 < nc/4; ++i0) { | ||
| 2112 | sumf += dot(s[i0], c[i0]); | ||
| 2113 | } | ||
| 2114 | |||
| 2115 | x[0] = sumf; | ||
| 2116 | } | ||
| 2117 | |||
| 2118 | // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part | ||
| 2119 | // Optimized version: reduces redundant memory loads by having one thread load shared values | ||
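| | // per-state recurrence: s = s_prev * exp(softplus(dt) * A) + B * (x * softplus(dt)), y = sum over states of s * C | ||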
| 2120 | kernel void kernel_ssm_scan_f32( | ||
| 2121 | constant ggml_metal_kargs_ssm_scan & args, | ||
| 2122 | device const void * src0, | ||
| 2123 | device const void * src1, | ||
| 2124 | device const void * src2, | ||
| 2125 | device const void * src3, | ||
| 2126 | device const void * src4, | ||
| 2127 | device const void * src5, | ||
| 2128 | device const void * src6, | ||
| 2129 | device float * dst, | ||
| 2130 | threadgroup float * shared [[threadgroup(0)]], | ||
| 2131 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 2132 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 2133 | ushort sgitg[[simdgroup_index_in_threadgroup]], | ||
| 2134 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 2135 | ushort sgptg[[simdgroups_per_threadgroup]], | ||
| 2136 | uint3 tgpg[[threadgroups_per_grid]]) { | ||
| 2137 | constexpr short NW = N_SIMDWIDTH; | ||
| 2138 | |||
| 2139 | // Shared memory layout: | ||
| 2140 | // [0..sgptg*NW-1]: partial sums for reduction (existing) | ||
| 2141 | // [sgptg*NW..sgptg*NW+sgptg-1]: pre-computed x_dt values for each token in batch | ||
| 2142 | // [sgptg*NW+sgptg..sgptg*NW+2*sgptg-1]: pre-computed dA values for each token in batch | ||
| 2143 | threadgroup float * shared_sums = shared; | ||
| 2144 | threadgroup float * shared_x_dt = shared + sgptg * NW; | ||
| 2145 | threadgroup float * shared_dA = shared + sgptg * NW + sgptg; | ||
| 2146 | |||
| 2147 | shared_sums[tpitg.x] = 0.0f; | ||
| 2148 | |||
| 2149 | const int32_t i0 = tpitg.x; | ||
| 2150 | const int32_t i1 = tgpig.x; | ||
| 2151 | const int32_t ir = tgpig.y; // current head | ||
| 2152 | const int32_t i3 = tgpig.z; // current seq | ||
| 2153 | |||
| 2154 | const int32_t nc = args.d_state; | ||
| 2155 | const int32_t nr = args.d_inner; | ||
| 2156 | const int32_t nh = args.n_head; | ||
| 2157 | const int32_t ng = args.n_group; | ||
| 2158 | const int32_t n_t = args.n_seq_tokens; | ||
| 2159 | |||
| 2160 | const int32_t s_off = args.s_off; | ||
| 2161 | |||
| 2162 | device const int32_t * ids = (device const int32_t *) src6; | ||
| 2163 | |||
| 2164 | device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); | ||
| 2165 | device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); | ||
| 2166 | |||
| 2167 | const int32_t i = i0 + i1*nc; | ||
| 2168 | const int32_t g = ir / (nh / ng); // repeat_interleave | ||
| 2169 | |||
| 2170 | float s0 = s0_buff[i]; | ||
| 2171 | float s = 0.0f; | ||
| 2172 | |||
| 2173 | device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh} | ||
| 2174 | |||
| 2175 | const float A0 = A[i0%args.ne30]; | ||
| 2176 | |||
| 2177 | device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns} | ||
| 2178 | device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns} | ||
| 2179 | device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns} | ||
| 2180 | device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns} | ||
| 2181 | |||
| 2182 | device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns} | ||
| 2183 | |||
| 2184 | for (int i2 = 0; i2 < n_t; i2 += sgptg) { | ||
| 2185 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2186 | |||
| 2187 | // Pre-compute x_dt and dA for this batch of tokens | ||
| 2188 | // Only first sgptg threads do the loads and expensive math | ||
| 2189 | if (i0 < sgptg && i2 + i0 < n_t) { | ||
| 2190 | // ns12 and ns21 are element strides (nb12/nb10, nb21/nb20) | ||
| 2191 | device const float * x_t = x + i0 * args.ns12; | ||
| 2192 | device const float * dt_t = dt + i0 * args.ns21; | ||
| 2193 | |||
| 2194 | const float dt0 = dt_t[0]; | ||
| 2195 | const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0; | ||
| 2196 | shared_x_dt[i0] = x_t[0] * dtsp; | ||
| 2197 | shared_dA[i0] = dtsp; // Store dtsp, compute exp(dtsp * A0) per-thread since A0 varies | ||
| 2198 | } | ||
| 2199 | |||
| 2200 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2201 | |||
| 2202 | for (int t = 0; t < sgptg && i2 + t < n_t; t++) { | ||
| 2203 | const float x_dt = shared_x_dt[t]; | ||
| 2204 | const float dA = exp(shared_dA[t] * A0); | ||
| 2205 | |||
| 2206 | s = (s0 * dA) + (B[i0] * x_dt); | ||
| 2207 | |||
| 2208 | const float sumf = simd_sum(s * C[i0]); | ||
| 2209 | |||
| 2210 | if (tiisg == 0) { | ||
| 2211 | shared_sums[t*NW + sgitg] = sumf; | ||
| 2212 | } | ||
| 2213 | |||
| 2214 | // recurrence: carry the state forward to the next token | ||
| 2215 | s0 = s; | ||
| 2216 | |||
| 2217 | B += args.ns42; | ||
| 2218 | C += args.ns52; | ||
| 2219 | } | ||
| 2220 | |||
| 2221 | // Advance pointers for next batch | ||
| 2222 | x += sgptg * args.ns12; | ||
| 2223 | dt += sgptg * args.ns21; | ||
| 2224 | |||
| 2225 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2226 | |||
| 2227 | const float sumf = simd_sum(shared_sums[sgitg*NW + tiisg]); | ||
| 2228 | |||
| 2229 | if (tiisg == 0 && i2 + sgitg < n_t) { | ||
| 2230 | y[sgitg*nh*nr] = sumf; | ||
| 2231 | } | ||
| 2232 | |||
| 2233 | y += sgptg*nh*nr; | ||
| 2234 | } | ||
| 2235 | |||
| 2236 | s_buff[i] = s; | ||
| 2237 | } | ||
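The scan kernel above spreads the state dimension across the threadgroup (one thread per state element) and processes sgptg tokens per iteration, carrying the state between tokens. As a minimal scalar sketch in plain C++ of the per-token recurrence it evaluates (function and variable names are illustrative; A is treated as one value per head, while the kernel reads A[i0 % ne30], and B/C are assumed contiguous per token):

```cpp
// Scalar reference for the selective-scan recurrence (illustrative, not part of the kernel):
//   dt' = softplus(dt),  s <- exp(dt'*A)*s + (x*dt')*B,  y = sum_over_state(s*C)
#include <cmath>
#include <vector>

static float softplus(float v) {
    return v <= 20.0f ? std::log(1.0f + std::exp(v)) : v;  // same threshold as the kernel
}

void ssm_scan_ref(std::vector<float> & s,              // [d_state] carried state for one head channel
                  float A,                             // decay coefficient for this channel
                  const float * x, const float * dt,   // [n_t] inputs per token
                  const float * B, const float * C,    // [n_t][d_state], assumed contiguous
                  float * y, int d_state, int n_t) {
    for (int t = 0; t < n_t; ++t) {
        const float dtsp = softplus(dt[t]);
        const float x_dt = x[t] * dtsp;
        const float dA   = std::exp(dtsp * A);
        float acc = 0.0f;
        for (int i = 0; i < d_state; ++i) {
            s[i] = s[i]*dA + B[t*d_state + i]*x_dt;    // recurrence carried across tokens
            acc += s[i] * C[t*d_state + i];            // the kernel reduces this with simd_sum
        }
        y[t] = acc;
    }
}
```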
| 2238 | |||
| 2239 | kernel void kernel_rwkv_wkv6_f32( | ||
| 2240 | device const float * k, | ||
| 2241 | device const float * v, | ||
| 2242 | device const float * r, | ||
| 2243 | device const float * tf, | ||
| 2244 | device const float * td, | ||
| 2245 | device const float * state_in, | ||
| 2246 | device float * dst, | ||
| 2247 | constant uint & B, | ||
| 2248 | constant uint & T, | ||
| 2249 | constant uint & C, | ||
| 2250 | constant uint & H, | ||
| 2251 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 2252 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 2253 | uint3 ntg[[threads_per_threadgroup]]) { | ||
| 2254 | |||
| 2255 | const uint head_size = 64; // TODO: support head_size = 128 | ||
| 2256 | const uint batch_id = tgpig.x / H; | ||
| 2257 | const uint head_id = tgpig.x % H; | ||
| 2258 | const uint tid = tpitg.x; | ||
| 2259 | |||
| 2260 | if (batch_id >= B || head_id >= H) { | ||
| 2261 | return; | ||
| 2262 | } | ||
| 2263 | |||
| 2264 | const uint state_size = C * head_size; | ||
| 2265 | const uint n_seq_tokens = T / B; | ||
| 2266 | |||
| 2267 | threadgroup float _k[head_size]; | ||
| 2268 | threadgroup float _r[head_size]; | ||
| 2269 | threadgroup float _tf[head_size]; | ||
| 2270 | threadgroup float _td[head_size]; | ||
| 2271 | |||
| 2272 | float state[head_size]; | ||
| 2273 | |||
| 2274 | for (uint i = 0; i < head_size; i++) { | ||
| 2275 | state[i] = state_in[batch_id * state_size + head_id * head_size * head_size | ||
| 2276 | + i * head_size + tid]; | ||
| 2277 | } | ||
| 2278 | |||
| 2279 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2280 | _tf[tid] = tf[head_id * head_size + tid]; | ||
| 2281 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2282 | |||
| 2283 | const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; | ||
| 2284 | const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; | ||
| 2285 | |||
| 2286 | for (uint t = start_t; t < end_t; t += C) { | ||
| 2287 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2288 | _k[tid] = k[t]; | ||
| 2289 | _r[tid] = r[t]; | ||
| 2290 | _td[tid] = td[t]; | ||
| 2291 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2292 | |||
| 2293 | const float v_val = v[t]; | ||
| 2294 | float y = 0.0; | ||
| 2295 | |||
| 2296 | for (uint j = 0; j < head_size; j += 4) { | ||
| 2297 | float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); | ||
| 2298 | float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); | ||
| 2299 | float4 tf_vec = float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); | ||
| 2300 | float4 td_vec = float4(_td[j], _td[j+1], _td[j+2], _td[j+3]); | ||
| 2301 | float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); | ||
| 2302 | |||
| 2303 | float4 kv = k_vec * v_val; | ||
| 2304 | |||
| 2305 | float4 temp = tf_vec * kv + s_vec; | ||
| 2306 | y += dot(r_vec, temp); | ||
| 2307 | |||
| 2308 | s_vec = s_vec * td_vec + kv; | ||
| 2309 | state[j] = s_vec[0]; | ||
| 2310 | state[j+1] = s_vec[1]; | ||
| 2311 | state[j+2] = s_vec[2]; | ||
| 2312 | state[j+3] = s_vec[3]; | ||
| 2313 | } | ||
| 2314 | |||
| 2315 | dst[t] = y; | ||
| 2316 | } | ||
| 2317 | |||
| 2318 | for (uint i = 0; i < head_size; i++) { | ||
| 2319 | dst[T * C + batch_id * state_size + head_id * head_size * head_size | ||
| 2320 | + i * head_size + tid] = state[i]; | ||
| 2321 | } | ||
| 2322 | } | ||
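The WKV6 kernel above keeps one state column per thread (tid) and walks the sequence token by token, staging k/r/td/tf in threadgroup memory. A per-token scalar sketch of the same update for one output channel, in plain C++ (names are illustrative):

```cpp
// WKV6 per-token update (illustrative scalar reference):
//   kv_j = k_j*v,  y = sum_j r_j*(tf_j*kv_j + state_j),  state_j <- state_j*td_j + kv_j
void wkv6_token_ref(const float * k, const float * r, const float * tf, const float * td,
                    float v, float * state, float * y_out, int head_size) {
    float y = 0.0f;
    for (int j = 0; j < head_size; ++j) {
        const float kv = k[j] * v;
        y       += r[j] * (tf[j] * kv + state[j]);  // bonus term tf applied to the current kv
        state[j] = state[j] * td[j] + kv;           // per-channel decay td
    }
    *y_out = y;
}
```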
| 2323 | |||
| 2324 | kernel void kernel_rwkv_wkv7_f32( | ||
| 2325 | device const float * r, | ||
| 2326 | device const float * w, | ||
| 2327 | device const float * k, | ||
| 2328 | device const float * v, | ||
| 2329 | device const float * a, | ||
| 2330 | device const float * b, | ||
| 2331 | device const float * state_in, | ||
| 2332 | device float * dst, | ||
| 2333 | constant uint & B, | ||
| 2334 | constant uint & T, | ||
| 2335 | constant uint & C, | ||
| 2336 | constant uint & H, | ||
| 2337 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 2338 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 2339 | uint3 ntg[[threads_per_threadgroup]]) { | ||
| 2340 | |||
| 2341 | const uint head_size = 64; // TODO: support head_size = 128 | ||
| 2342 | const uint batch_id = tgpig.x / H; | ||
| 2343 | const uint head_id = tgpig.x % H; | ||
| 2344 | const uint tid = tpitg.x; | ||
| 2345 | |||
| 2346 | if (batch_id >= B || head_id >= H) { | ||
| 2347 | return; | ||
| 2348 | } | ||
| 2349 | |||
| 2350 | const uint state_size = C * head_size; | ||
| 2351 | const uint n_seq_tokens = T / B; | ||
| 2352 | |||
| 2353 | threadgroup float _r[head_size]; | ||
| 2354 | threadgroup float _w[head_size]; | ||
| 2355 | threadgroup float _k[head_size]; | ||
| 2356 | threadgroup float _a[head_size]; | ||
| 2357 | threadgroup float _b[head_size]; | ||
| 2358 | |||
| 2359 | float state[head_size]; | ||
| 2360 | |||
| 2361 | for (uint i = 0; i < head_size; i++) { | ||
| 2362 | state[i] = state_in[batch_id * state_size + head_id * head_size * head_size | ||
| 2363 | + tid * head_size + i]; | ||
| 2364 | } | ||
| 2365 | |||
| 2366 | const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; | ||
| 2367 | const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; | ||
| 2368 | |||
| 2369 | for (uint t = start_t; t < end_t; t += C) { | ||
| 2370 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2371 | _r[tid] = r[t]; | ||
| 2372 | _w[tid] = w[t]; | ||
| 2373 | _k[tid] = k[t]; | ||
| 2374 | _a[tid] = a[t]; | ||
| 2375 | _b[tid] = b[t]; | ||
| 2376 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2377 | |||
| 2378 | const float v_val = v[t]; | ||
| 2379 | float y = 0.0, sa = 0.0; | ||
| 2380 | |||
| 2381 | float4 sa_vec(0.0); | ||
| 2382 | |||
| 2383 | for (uint j = 0; j < head_size; j += 4) { | ||
| 2384 | float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]); | ||
| 2385 | float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); | ||
| 2386 | sa_vec += a_vec * s_vec; | ||
| 2387 | } | ||
| 2388 | sa = sa_vec[0] + sa_vec[1] + sa_vec[2] + sa_vec[3]; | ||
| 2389 | |||
| 2390 | for (uint j = 0; j < head_size; j += 4) { | ||
| 2391 | float4 r_vec = float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); | ||
| 2392 | float4 w_vec = float4(_w[j], _w[j+1], _w[j+2], _w[j+3]); | ||
| 2393 | float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); | ||
| 2394 | float4 b_vec = float4(_b[j], _b[j+1], _b[j+2], _b[j+3]); | ||
| 2395 | float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); | ||
| 2396 | |||
| 2397 | float4 kv = k_vec * v_val; | ||
| 2398 | |||
| 2399 | s_vec = s_vec * w_vec + kv + sa * b_vec; | ||
| 2400 | y += dot(s_vec, r_vec); | ||
| 2401 | |||
| 2402 | state[j] = s_vec[0]; | ||
| 2403 | state[j+1] = s_vec[1]; | ||
| 2404 | state[j+2] = s_vec[2]; | ||
| 2405 | state[j+3] = s_vec[3]; | ||
| 2406 | } | ||
| 2407 | |||
| 2408 | dst[t] = y; | ||
| 2409 | } | ||
| 2410 | |||
| 2411 | for (uint i = 0; i < head_size; i++) { | ||
| 2412 | dst[T * C + batch_id * state_size + head_id * head_size * head_size | ||
| 2413 | + tid * head_size + i] = state[i]; | ||
| 2414 | } | ||
| 2415 | } | ||
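The WKV7 variant replaces the in-loop bonus of WKV6 with a state correction driven by a and b: the state is first projected onto a, and that scalar is re-injected through b while the state decays by w. A scalar sketch of one token step in plain C++ (illustrative names):

```cpp
// WKV7 per-token update (illustrative scalar reference):
//   sa = sum_j a_j*state_j,  state_j <- state_j*w_j + k_j*v + sa*b_j,  y = sum_j state_j*r_j
void wkv7_token_ref(const float * r, const float * w, const float * k,
                    const float * a, const float * b,
                    float v, float * state, float * y_out, int head_size) {
    float sa = 0.0f;
    for (int j = 0; j < head_size; ++j) {
        sa += a[j] * state[j];                         // first pass: project the state onto a
    }
    float y = 0.0f;
    for (int j = 0; j < head_size; ++j) {
        state[j] = state[j]*w[j] + k[j]*v + sa*b[j];   // decay, inject kv, add the correction
        y       += state[j] * r[j];
    }
    *y_out = y;
}
```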
| 2416 | |||
| 2417 | constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]]; | ||
| 2418 | constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]]; | ||
| 2419 | constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]]; | ||
| 2420 | |||
| 2421 | kernel void kernel_solve_tri_f32( | ||
| 2422 | constant ggml_metal_kargs_solve_tri & args, | ||
| 2423 | device const char * src0, | ||
| 2424 | device const char * src1, | ||
| 2425 | device char * dst, | ||
| 2426 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 2427 | ushort3 tgpig[[threadgroup_position_in_grid]], | ||
| 2428 | ushort sgitg[[simdgroup_index_in_threadgroup]], | ||
| 2429 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 2430 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 2431 | constexpr short NW = N_SIMDWIDTH; | ||
| 2432 | |||
| 2433 | const short NSG = FC_solve_tri_nsg; | ||
| 2434 | const short N = FC_solve_tri_n; | ||
| 2435 | const short K = FC_solve_tri_k; | ||
| 2436 | const short NP = PAD2(N, NW); | ||
| 2437 | |||
| 2438 | const int32_t ne02 = args.ne02; | ||
| 2439 | const int32_t ne03 = args.ne03; | ||
| 2440 | |||
| 2441 | const int32_t i03 = tgpig.z; | ||
| 2442 | const int32_t i02 = tgpig.y; | ||
| 2443 | const int32_t i01 = tgpig.x*NSG + sgitg; | ||
| 2444 | |||
| 2445 | threadgroup float * sh0 = (threadgroup float *) shmem; | ||
| 2446 | |||
| 2447 | device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N; | ||
| 2448 | device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01; | ||
| 2449 | device float * dst_ptr = (device float *)(dst + i02 * args.nb2 + i03 * args.nb3) + i01; | ||
| 2450 | |||
| 2451 | for (short rr = 0; rr < N; rr += NSG) { | ||
| 2452 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2453 | |||
| 2454 | { | ||
| 2455 | threadgroup float * sh0_cur = sh0 + sgitg*NP; | ||
| 2456 | |||
| 2457 | for (short t = 0; t*NW < N; ++t) { | ||
| 2458 | const short idx = t*NW + tiisg; | ||
| 2459 | sh0_cur[idx] = src0_ptr[idx]; | ||
| 2460 | } | ||
| 2461 | |||
| 2462 | src0_ptr += NSG*N; | ||
| 2463 | } | ||
| 2464 | |||
| 2465 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2466 | |||
| 2467 | if (i01 >= args.ne10) { | ||
| 2468 | continue; | ||
| 2469 | } | ||
| 2470 | |||
| 2471 | for (short ir = 0; ir < NSG && rr + ir < N; ++ir) { | ||
| 2472 | const short r = rr + ir; | ||
| 2473 | |||
| 2474 | threadgroup float * sh0_cur = sh0 + ir*NP; | ||
| 2475 | |||
| 2476 | float sum = 0.0f; | ||
| 2477 | |||
| 2478 | for (short t = 0; t*NW < r; ++t) { | ||
| 2479 | const short idx = t*NW + tiisg; | ||
| 2480 | sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r); | ||
| 2481 | } | ||
| 2482 | |||
| 2483 | sum = simd_sum(sum); | ||
| 2484 | |||
| 2485 | if (tiisg == 0) { | ||
| 2486 | const float diag = sh0_cur[r]; | ||
| 2487 | |||
| 2488 | dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag; | ||
| 2489 | } | ||
| 2490 | } | ||
| 2491 | } | ||
| 2492 | } | ||
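kernel_solve_tri_f32 above is forward substitution for a lower-triangular system L·X = B: the simdgroups cooperatively stage NSG rows of L in threadgroup memory per iteration, and each simdgroup handles one right-hand-side column (i01), striding that column by K. The same computation for a single contiguous column, as a plain C++ sketch (names illustrative):

```cpp
// Forward substitution for L*x = b with lower-triangular L (row-major, N x N):
void solve_tri_lower_ref(const float * L, const float * b, float * x, int N) {
    for (int r = 0; r < N; ++r) {
        float sum = 0.0f;
        for (int i = 0; i < r; ++i) {
            sum += L[r*N + i] * x[i];       // the kernel reduces this partial sum with simd_sum
        }
        x[r] = (b[r] - sum) / L[r*N + r];   // divide by the diagonal element
    }
}
```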
| 2493 | |||
| 2494 | kernel void kernel_argmax_f32( | ||
| 2495 | constant ggml_metal_kargs_argmax & args, | ||
| 2496 | device const char * src0, | ||
| 2497 | device char * dst, | ||
| 2498 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 2499 | uint tgpig[[threadgroup_position_in_grid]], | ||
| 2500 | uint tpitg[[thread_position_in_threadgroup]], | ||
| 2501 | uint sgitg[[simdgroup_index_in_threadgroup]], | ||
| 2502 | uint tiisg[[thread_index_in_simdgroup]], | ||
| 2503 | uint ntg[[threads_per_threadgroup]]) { | ||
| 2504 | device const float * x_row = (device const float *) ((device const char *) src0 + tgpig * args.nb01); | ||
| 2505 | |||
| 2506 | float lmax = -INFINITY; | ||
| 2507 | int32_t larg = -1; | ||
| 2508 | |||
| 2509 | for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) { | ||
| 2510 | if (x_row[i00] > lmax) { | ||
| 2511 | lmax = x_row[i00]; | ||
| 2512 | larg = i00; | ||
| 2513 | } | ||
| 2514 | } | ||
| 2515 | |||
| 2516 | // find the argmax value in the block | ||
| 2517 | float max_val = simd_max(lmax); | ||
| 2518 | int32_t arg_val = simd_max(select(-1, larg, lmax == max_val)); | ||
| 2519 | |||
| 2520 | device int32_t * dst_i32 = (device int32_t *) dst; | ||
| 2521 | |||
| 2522 | threadgroup float * shared_maxval = (threadgroup float *) shmem; | ||
| 2523 | threadgroup int32_t * shared_argmax = (threadgroup int32_t *) shmem + N_SIMDWIDTH; | ||
| 2524 | |||
| 2525 | if (ntg > N_SIMDWIDTH) { | ||
| 2526 | if (sgitg == 0) { | ||
| 2527 | shared_maxval[tiisg] = -INFINITY; | ||
| 2528 | shared_argmax[tiisg] = -1; | ||
| 2529 | } | ||
| 2530 | |||
| 2531 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2532 | |||
| 2533 | if (tiisg == 0) { | ||
| 2534 | shared_maxval[sgitg] = max_val; | ||
| 2535 | shared_argmax[sgitg] = arg_val; | ||
| 2536 | } | ||
| 2537 | |||
| 2538 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2539 | |||
| 2540 | max_val = shared_maxval[tiisg]; | ||
| 2541 | arg_val = shared_argmax[tiisg]; | ||
| 2542 | |||
| 2543 | float max_val_reduced = simd_max(max_val); | ||
| 2544 | int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced)); | ||
| 2545 | |||
| 2546 | dst_i32[tgpig] = arg_val_reduced; | ||
| 2547 | |||
| 2548 | return; | ||
| 2549 | } | ||
| 2550 | |||
| 2551 | dst_i32[tgpig] = arg_val; | ||
| 2552 | } | ||
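The argmax kernel reduces in two steps per simdgroup: simd_max over the values, then simd_max over the candidate indices after masking non-maximal lanes to -1 with select; when there is more than one simdgroup, the per-simdgroup results pass through threadgroup memory and are reduced once more. A serial equivalent of one row in plain C++ (illustrative; on ties the simdgroup version may pick a different, larger index):

```cpp
#include <cmath>

// Serial argmax over one row (reference for the two-stage simdgroup reduction above):
int argmax_ref(const float * x, int n) {
    float best = -INFINITY;
    int   arg  = -1;
    for (int i = 0; i < n; ++i) {
        if (x[i] > best) {
            best = x[i];
            arg  = i;
        }
    }
    return arg;
}
```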
| 2553 | |||
| 2554 | // F == 1 : norm (no fuse) | ||
| 2555 | // F == 2 : norm + mul | ||
| 2556 | // F == 3 : norm + mul + add | ||
| 2557 | template <typename T, short F> | ||
| 2558 | kernel void kernel_norm_fuse_impl( | ||
| 2559 | constant ggml_metal_kargs_norm & args, | ||
| 2560 | device const char * src0, | ||
| 2561 | device const char * src1_0, | ||
| 2562 | device const char * src1_1, | ||
| 2563 | device char * dst, | ||
| 2564 | threadgroup float * shmem_f32 [[threadgroup(0)]], | ||
| 2565 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 2566 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 2567 | ushort sgitg[[simdgroup_index_in_threadgroup]], | ||
| 2568 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 2569 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 2570 | if (sgitg == 0) { | ||
| 2571 | shmem_f32[tiisg] = 0.0f; | ||
| 2572 | } | ||
| 2573 | |||
| 2574 | const int i01 = tgpig.x; | ||
| 2575 | const int i02 = tgpig.y; | ||
| 2576 | const int i03 = tgpig.z; | ||
| 2577 | |||
| 2578 | device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]); | ||
| 2579 | |||
| 2580 | device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]); | ||
| 2581 | device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]); | ||
| 2582 | |||
| 2583 | T sumft(0.0f); | ||
| 2584 | |||
| 2585 | float sumf = 0.0f; | ||
| 2586 | |||
| 2587 | for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { | ||
| 2588 | sumft += x[i00]; | ||
| 2589 | } | ||
| 2590 | sumf = dot(sumft, T(1.0f)); | ||
| 2591 | sumf = simd_sum(sumf); | ||
| 2592 | |||
| 2593 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2594 | |||
| 2595 | if (tiisg == 0) { | ||
| 2596 | shmem_f32[sgitg] = sumf; | ||
| 2597 | } | ||
| 2598 | |||
| 2599 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2600 | |||
| 2601 | sumf = shmem_f32[tiisg]; | ||
| 2602 | sumf = simd_sum(sumf); | ||
| 2603 | |||
| 2604 | const float mean = sumf/args.ne00; | ||
| 2605 | |||
| 2606 | device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); | ||
| 2607 | |||
| 2608 | sumf = 0.0f; | ||
| 2609 | for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { | ||
| 2610 | y[i00] = x[i00] - mean; | ||
| 2611 | sumf += dot(y[i00], y[i00]); | ||
| 2612 | } | ||
| 2613 | sumf = simd_sum(sumf); | ||
| 2614 | |||
| 2615 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2616 | |||
| 2617 | if (tiisg == 0) { | ||
| 2618 | shmem_f32[sgitg] = sumf; | ||
| 2619 | } | ||
| 2620 | |||
| 2621 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2622 | |||
| 2623 | sumf = shmem_f32[tiisg]; | ||
| 2624 | sumf = simd_sum(sumf); | ||
| 2625 | |||
| 2626 | const float variance = sumf/args.ne00; | ||
| 2627 | |||
| 2628 | const float scale = 1.0f/sqrt(variance + args.eps); | ||
| 2629 | for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { | ||
| 2630 | if (F == 1) { | ||
| 2631 | y[i00] = (y[i00]*scale); | ||
| 2632 | } | ||
| 2633 | if (F == 2) { | ||
| 2634 | y[i00] = (y[i00]*scale)*f0[i00]; | ||
| 2635 | } | ||
| 2636 | if (F == 3) { | ||
| 2637 | y[i00] = (y[i00]*scale)*f0[i00] + f1[i00]; | ||
| 2638 | } | ||
| 2639 | } | ||
| 2640 | } | ||
| 2641 | |||
| 2642 | typedef decltype(kernel_norm_fuse_impl<float4, 1>) kernel_norm_fuse_t; | ||
| 2643 | |||
| 2644 | template [[host_name("kernel_norm_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 1>; | ||
| 2645 | template [[host_name("kernel_norm_mul_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 2>; | ||
| 2646 | template [[host_name("kernel_norm_mul_add_f32")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float, 3>; | ||
| 2647 | |||
| 2648 | template [[host_name("kernel_norm_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 1>; | ||
| 2649 | template [[host_name("kernel_norm_mul_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 2>; | ||
| 2650 | template [[host_name("kernel_norm_mul_add_f32_4")]] kernel kernel_norm_fuse_t kernel_norm_fuse_impl<float4, 3>; | ||
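The fused norm family above computes a standard layer norm over ne00 elements and, depending on F, multiplies by f0 and adds f1 in the same pass. A scalar reference in plain C++ (illustrative names):

```cpp
#include <cmath>

// y = (x - mean)/sqrt(var + eps), then *f0 for F >= 2 and +f1 for F == 3:
void norm_fuse_ref(const float * x, const float * f0, const float * f1,
                   float * y, int n, float eps, int F) {
    float mean = 0.0f;
    for (int i = 0; i < n; ++i) mean += x[i];
    mean /= n;

    float var = 0.0f;
    for (int i = 0; i < n; ++i) { y[i] = x[i] - mean; var += y[i]*y[i]; }
    var /= n;

    const float scale = 1.0f/std::sqrt(var + eps);
    for (int i = 0; i < n; ++i) {
        y[i] *= scale;
        if (F >= 2) y[i] *= f0[i];
        if (F == 3) y[i] += f1[i];
    }
}
```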
| 2651 | |||
| 2652 | // F == 1 : rms_norm (no fuse) | ||
| 2653 | // F == 2 : rms_norm + mul | ||
| 2654 | // F == 3 : rms_norm + mul + add | ||
| 2655 | template <typename T, short F> | ||
| 2656 | kernel void kernel_rms_norm_fuse_impl( | ||
| 2657 | constant ggml_metal_kargs_norm & args, | ||
| 2658 | device const char * src0, | ||
| 2659 | device const char * src1_0, | ||
| 2660 | device const char * src1_1, | ||
| 2661 | device char * dst, | ||
| 2662 | threadgroup float * shmem_f32 [[threadgroup(0)]], | ||
| 2663 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 2664 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 2665 | ushort sgitg[[simdgroup_index_in_threadgroup]], | ||
| 2666 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 2667 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 2668 | if (sgitg == 0) { | ||
| 2669 | shmem_f32[tiisg] = 0.0f; | ||
| 2670 | } | ||
| 2671 | |||
| 2672 | const int i01 = tgpig.x; | ||
| 2673 | const int i02 = tgpig.y; | ||
| 2674 | const int i03 = tgpig.z; | ||
| 2675 | |||
| 2676 | device const T * x = (device const T *) (src0 + i03*args.nbf3[0] + i02*args.nbf2[0] + i01*args.nbf1[0]); | ||
| 2677 | |||
| 2678 | device const T * f0 = (device const T *) (src1_0 + (i03%args.nef3[1])*args.nbf3[1] + (i02%args.nef2[1])*args.nbf2[1] + (i01%args.nef1[1])*args.nbf1[1]); | ||
| 2679 | device const T * f1 = (device const T *) (src1_1 + (i03%args.nef3[2])*args.nbf3[2] + (i02%args.nef2[2])*args.nbf2[2] + (i01%args.nef1[2])*args.nbf1[2]); | ||
| 2680 | |||
| 2681 | float sumf = 0.0f; | ||
| 2682 | |||
| 2683 | // parallel sum | ||
| 2684 | for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { | ||
| 2685 | sumf += dot(x[i00], x[i00]); | ||
| 2686 | } | ||
| 2687 | sumf = simd_sum(sumf); | ||
| 2688 | |||
| 2689 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2690 | |||
| 2691 | if (tiisg == 0) { | ||
| 2692 | shmem_f32[sgitg] = sumf; | ||
| 2693 | } | ||
| 2694 | |||
| 2695 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2696 | |||
| 2697 | sumf = shmem_f32[tiisg]; | ||
| 2698 | sumf = simd_sum(sumf); | ||
| 2699 | |||
| 2700 | const float mean = sumf/args.ne00; | ||
| 2701 | const float scale = 1.0f/sqrt(mean + args.eps); | ||
| 2702 | |||
| 2703 | device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); | ||
| 2704 | for (int i00 = tpitg.x; i00 < args.ne00_t; i00 += ntg.x) { | ||
| 2705 | if (F == 1) { | ||
| 2706 | y[i00] = (x[i00]*scale); | ||
| 2707 | } | ||
| 2708 | if (F == 2) { | ||
| 2709 | y[i00] = (x[i00]*scale)*f0[i00]; | ||
| 2710 | } | ||
| 2711 | if (F == 3) { | ||
| 2712 | y[i00] = (x[i00]*scale)*f0[i00] + f1[i00]; | ||
| 2713 | } | ||
| 2714 | } | ||
| 2715 | } | ||
| 2716 | |||
| 2717 | typedef decltype(kernel_rms_norm_fuse_impl<float4, 1>) kernel_rms_norm_fuse_t; | ||
| 2718 | |||
| 2719 | template [[host_name("kernel_rms_norm_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 1>; | ||
| 2720 | template [[host_name("kernel_rms_norm_mul_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 2>; | ||
| 2721 | template [[host_name("kernel_rms_norm_mul_add_f32")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float, 3>; | ||
| 2722 | |||
| 2723 | template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 1>; | ||
| 2724 | template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>; | ||
| 2725 | template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>; | ||
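The RMS-norm variants skip the mean subtraction and normalize by the root mean square instead; the fusion with mul/add follows the same F scheme. Scalar sketch in plain C++ (illustrative):

```cpp
#include <cmath>

// y = x/sqrt(mean(x^2) + eps), then *f0 for F >= 2 and +f1 for F == 3:
void rms_norm_fuse_ref(const float * x, const float * f0, const float * f1,
                       float * y, int n, float eps, int F) {
    float ss = 0.0f;
    for (int i = 0; i < n; ++i) ss += x[i]*x[i];
    const float scale = 1.0f/std::sqrt(ss/n + eps);
    for (int i = 0; i < n; ++i) {
        float v = x[i]*scale;
        if (F >= 2) v *= f0[i];
        if (F == 3) v += f1[i];
        y[i] = v;
    }
}
```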
| 2726 | |||
| 2727 | template <typename T0, typename T> | ||
| 2728 | kernel void kernel_l2_norm_impl( | ||
| 2729 | constant ggml_metal_kargs_l2_norm & args, | ||
| 2730 | device const char * src0, | ||
| 2731 | device char * dst, | ||
| 2732 | threadgroup float * shmem_f32 [[threadgroup(0)]], | ||
| 2733 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 2734 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 2735 | ushort sgitg[[simdgroup_index_in_threadgroup]], | ||
| 2736 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 2737 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 2738 | const int i03 = tgpig.z; | ||
| 2739 | const int i02 = tgpig.y; | ||
| 2740 | const int i01 = tgpig.x; | ||
| 2741 | |||
| 2742 | if (sgitg == 0) { | ||
| 2743 | shmem_f32[tiisg] = 0.0f; | ||
| 2744 | } | ||
| 2745 | |||
| 2746 | device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); | ||
| 2747 | device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); | ||
| 2748 | |||
| 2749 | float sumf = 0.0f; | ||
| 2750 | |||
| 2751 | // parallel sum | ||
| 2752 | for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { | ||
| 2753 | sumf += dot(x[i00], x[i00]); | ||
| 2754 | } | ||
| 2755 | sumf = simd_sum(sumf); | ||
| 2756 | |||
| 2757 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2758 | |||
| 2759 | if (tiisg == 0) { | ||
| 2760 | shmem_f32[sgitg] = sumf; | ||
| 2761 | } | ||
| 2762 | |||
| 2763 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2764 | |||
| 2765 | sumf = shmem_f32[tiisg]; | ||
| 2766 | sumf = simd_sum(sumf); | ||
| 2767 | |||
| 2768 | const float scale = 1.0f/sqrt(max(sumf, args.eps)); | ||
| 2769 | |||
| 2770 | for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { | ||
| 2771 | y[i00] = x[i00] * scale; | ||
| 2772 | } | ||
| 2773 | } | ||
| 2774 | |||
| 2775 | typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t; | ||
| 2776 | |||
| 2777 | template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>; | ||
| 2778 | template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>; | ||
| 2779 | |||
| 2780 | kernel void kernel_group_norm_f32( | ||
| 2781 | constant ggml_metal_kargs_group_norm & args, | ||
| 2782 | device const float * src0, | ||
| 2783 | device float * dst, | ||
| 2784 | threadgroup float * buf [[threadgroup(0)]], | ||
| 2785 | uint tgpig[[threadgroup_position_in_grid]], | ||
| 2786 | uint tpitg[[thread_position_in_threadgroup]], | ||
| 2787 | uint sgitg[[simdgroup_index_in_threadgroup]], | ||
| 2788 | uint tiisg[[thread_index_in_simdgroup]], | ||
| 2789 | uint ntg[[threads_per_threadgroup]]) { | ||
| 2790 | const int64_t ne = args.ne00*args.ne01*args.ne02; | ||
| 2791 | const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.ngrp - 1) / args.ngrp); | ||
| 2792 | |||
| 2793 | int start = tgpig * gs; | ||
| 2794 | int end = start + gs; | ||
| 2795 | |||
| 2796 | start += tpitg; | ||
| 2797 | |||
| 2798 | if (end >= ne) { | ||
| 2799 | end = ne; | ||
| 2800 | } | ||
| 2801 | |||
| 2802 | float tmp = 0.0f; // partial sum for this thread in the simdgroup | ||

| 2803 | |||
| 2804 | for (int j = start; j < end; j += ntg) { | ||
| 2805 | tmp += src0[j]; | ||
| 2806 | } | ||
| 2807 | |||
| 2808 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2809 | tmp = simd_sum(tmp); | ||
| 2810 | if (ntg > N_SIMDWIDTH) { | ||
| 2811 | if (sgitg == 0) { | ||
| 2812 | buf[tiisg] = 0.0f; | ||
| 2813 | } | ||
| 2814 | |||
| 2815 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2816 | |||
| 2817 | if (tiisg == 0) { | ||
| 2818 | buf[sgitg] = tmp; | ||
| 2819 | } | ||
| 2820 | |||
| 2821 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2822 | |||
| 2823 | tmp = buf[tiisg]; | ||
| 2824 | tmp = simd_sum(tmp); | ||
| 2825 | } | ||
| 2826 | |||
| 2827 | const float mean = tmp / gs; | ||
| 2828 | tmp = 0.0f; | ||
| 2829 | |||
| 2830 | for (int j = start; j < end; j += ntg) { | ||
| 2831 | float xi = src0[j] - mean; | ||
| 2832 | dst[j] = xi; | ||
| 2833 | tmp += xi * xi; | ||
| 2834 | } | ||
| 2835 | |||
| 2836 | tmp = simd_sum(tmp); | ||
| 2837 | if (ntg > N_SIMDWIDTH) { | ||
| 2838 | if (sgitg == 0) { | ||
| 2839 | buf[tiisg] = 0.0f; | ||
| 2840 | } | ||
| 2841 | |||
| 2842 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2843 | |||
| 2844 | if (tiisg == 0) { | ||
| 2845 | buf[sgitg] = tmp; | ||
| 2846 | } | ||
| 2847 | |||
| 2848 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2849 | |||
| 2850 | tmp = buf[tiisg]; | ||
| 2851 | tmp = simd_sum(tmp); | ||
| 2852 | } | ||
| 2853 | |||
| 2854 | const float variance = tmp / gs; | ||
| 2855 | const float scale = 1.0f/sqrt(variance + args.eps); | ||
| 2856 | for (int j = start; j < end; j += ntg) { | ||
| 2857 | dst[j] *= scale; | ||
| 2858 | } | ||
| 2859 | } | ||
| 2860 | |||
| 2861 | // function to calculate the inner product between half a q4_0 block and 16 floats (yl); sumy is SUM(yl[i]) | ||
| 2862 | // il indicates where the q4 quants begin (0 or QK4_0/4) | ||
| 2863 | // we assume that the yl's have been multiplied by the appropriate scale factor | ||
| 2864 | // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) | ||
| 2865 | inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { | ||
| 2866 | float d = qb_curr->d; | ||
| 2867 | |||
| 2868 | float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; | ||
| 2869 | |||
| 2870 | device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2); | ||
| 2871 | |||
| 2872 | for (int i = 0; i < 8; i += 2) { | ||
| 2873 | acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F); | ||
| 2874 | acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00); | ||
| 2875 | acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0); | ||
| 2876 | acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000); | ||
| 2877 | } | ||
| 2878 | |||
| 2879 | return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]); | ||
| 2880 | } | ||
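The "missing bit shifts" mentioned in the comment work like this: a nibble masked out of a packed uint16 without shifting is effectively q << k, so the activations (filled further below in mul_vec_q_n_f32_impl) are pre-divided by 2^k, and the q4_0 zero-point of -8 per weight is folded in once at the end via sumy * -8.f. A small self-checking C++ snippet of the underlying identity (illustrative, not part of the kernel):

```cpp
#include <cassert>
#include <cstdint>

int main() {
    const uint16_t packed = 0x4321;  // four packed nibbles: 1, 2, 3, 4
    const float y = 0.5f;
    assert((packed & 0x000F) * (y / 1.0f)    == 1 * y);  // bits  0..3,  scale 1
    assert((packed & 0x00F0) * (y / 16.0f)   == 2 * y);  // bits  4..7,  scale 1/16
    assert((packed & 0x0F00) * (y / 256.0f)  == 3 * y);  // bits  8..11, scale 1/256
    assert((packed & 0xF000) * (y / 4096.0f) == 4 * y);  // bits 12..15, scale 1/4096
    return 0;
}
```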
| 2881 | |||
| 2882 | // function to calculate the inner product between half a q4_1 block and 16 floats (yl); sumy is SUM(yl[i]) | ||
| 2883 | // il indicates where the q4 quants begin (0 or QK4_0/4) | ||
| 2884 | // we assume that the yl's have been multiplied by the appropriate scale factor | ||
| 2885 | // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) | ||
| 2886 | inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { | ||
| 2887 | float d = qb_curr->d; | ||
| 2888 | float m = qb_curr->m; | ||
| 2889 | |||
| 2890 | float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; | ||
| 2891 | |||
| 2892 | device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2); | ||
| 2893 | |||
| 2894 | for (int i = 0; i < 8; i+=2) { | ||
| 2895 | acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F); | ||
| 2896 | acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00); | ||
| 2897 | acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0); | ||
| 2898 | acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000); | ||
| 2899 | } | ||
| 2900 | |||
| 2901 | return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; | ||
| 2902 | } | ||
| 2903 | |||
| 2904 | // function to calculate the inner product between half a q5_0 block and 16 floats (yl); sumy is SUM(yl[i]) | ||
| 2905 | // il indicates where the q5 quants begin (0 or QK5_0/4) | ||
| 2906 | // we assume that the yl's have been multiplied by the appropriate scale factor | ||
| 2907 | // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) | ||
| 2908 | inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { | ||
| 2909 | float d = qb_curr->d; | ||
| 2910 | |||
| 2911 | float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; | ||
| 2912 | |||
| 2913 | device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); | ||
| 2914 | const uint32_t qh = *((device const uint32_t *)qb_curr->qh); | ||
| 2915 | |||
| 2916 | for (int i = 0; i < 8; i+=2) { | ||
| 2917 | acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)); | ||
| 2918 | acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); | ||
| 2919 | acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)); | ||
| 2920 | acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); | ||
| 2921 | } | ||
| 2922 | |||
| 2923 | return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]); | ||
| 2924 | } | ||
| 2925 | |||
| 2926 | // function to calculate the inner product between half a q5_1 block and 16 floats (yl); sumy is SUM(yl[i]) | ||
| 2927 | // il indicates where the q5 quants begin (0 or QK5_1/4) | ||
| 2928 | // we assume that the yl's have been multiplied by the appropriate scale factor | ||
| 2929 | // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) | ||
| 2930 | inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { | ||
| 2931 | float d = qb_curr->d; | ||
| 2932 | float m = qb_curr->m; | ||
| 2933 | |||
| 2934 | float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f }; | ||
| 2935 | |||
| 2936 | device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); | ||
| 2937 | const uint32_t qh = *((device const uint32_t *)qb_curr->qh); | ||
| 2938 | |||
| 2939 | for (int i = 0; i < 8; i+=2) { | ||
| 2940 | acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)); | ||
| 2941 | acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); | ||
| 2942 | acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)); | ||
| 2943 | acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); | ||
| 2944 | } | ||
| 2945 | |||
| 2946 | return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m; | ||
| 2947 | } | ||
| 2948 | |||
| 2949 | template<short NR0> | ||
| 2950 | static inline void helper_mv_reduce_and_write( | ||
| 2951 | device float * dst_f32, | ||
| 2952 | float sumf[NR0], | ||
| 2953 | const int r0, | ||
| 2954 | const int ne01, | ||
| 2955 | ushort tiisg, | ||
| 2956 | ushort sgitg, | ||
| 2957 | threadgroup char * shmem) { | ||
| 2958 | constexpr short NW = N_SIMDWIDTH; | ||
| 2959 | |||
| 2960 | threadgroup float * shmem_f32[NR0]; | ||
| 2961 | |||
| 2962 | for (short row = 0; row < NR0; ++row) { | ||
| 2963 | shmem_f32[row] = (threadgroup float *) shmem + NW*row; | ||
| 2964 | |||
| 2965 | if (sgitg == 0) { | ||
| 2966 | shmem_f32[row][tiisg] = 0.0f; | ||
| 2967 | } | ||
| 2968 | |||
| 2969 | sumf[row] = simd_sum(sumf[row]); | ||
| 2970 | } | ||
| 2971 | |||
| 2972 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2973 | |||
| 2974 | for (short row = 0; row < NR0; ++row) { | ||
| 2975 | if (tiisg == 0) { | ||
| 2976 | shmem_f32[row][sgitg] = sumf[row]; | ||
| 2977 | } | ||
| 2978 | } | ||
| 2979 | |||
| 2980 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 2981 | |||
| 2982 | for (short row = 0; row < NR0 && r0 + row < ne01; ++row) { | ||
| 2983 | float tot = simd_sum(shmem_f32[row][tiisg]); | ||
| 2984 | |||
| 2985 | if (tiisg == 0 && sgitg == 0) { | ||
| 2986 | dst_f32[r0 + row] = tot; | ||
| 2987 | } | ||
| 2988 | } | ||
| 2989 | } | ||
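helper_mv_reduce_and_write combines a within-simdgroup simd_sum with a second reduction across simdgroups through threadgroup memory, so each of the NR0 rows ends up as a single total written by lane 0 of simdgroup 0. The shape of that two-level reduction, modelled serially in plain C++ (illustrative):

```cpp
#include <vector>

// Two-level sum: one partial per simdgroup, then a sum over the partials.
float two_level_reduce_ref(const std::vector<std::vector<float>> & lanes /* [n_simdgroups][32] */) {
    std::vector<float> partial;
    partial.reserve(lanes.size());
    for (const auto & sg : lanes) {
        float p = 0.0f;
        for (float v : sg) p += v;       // level 1: simd_sum within a simdgroup
        partial.push_back(p);
    }
    float total = 0.0f;
    for (float p : partial) total += p;  // level 2: reduce the per-simdgroup partials
    return total;
}
```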
| 2990 | |||
| 2991 | constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]]; | ||
| 2992 | constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]]; | ||
| 2993 | |||
| 2994 | template<typename block_q_type, short NR0, typename args_t> | ||
| 2995 | void mul_vec_q_n_f32_impl( | ||
| 2996 | args_t args, | ||
| 2997 | device const char * src0, | ||
| 2998 | device const char * src1, | ||
| 2999 | device char * dst, | ||
| 3000 | threadgroup char * shmem, | ||
| 3001 | uint3 tgpig, | ||
| 3002 | ushort tiisg, | ||
| 3003 | ushort sgitg) { | ||
| 3004 | const short NSG = FC_mul_mv_nsg; | ||
| 3005 | |||
| 3006 | constexpr short NW = N_SIMDWIDTH; | ||
| 3007 | constexpr short NQ = 16; | ||
| 3008 | |||
| 3009 | const int nb = args.ne00/QK4_0; | ||
| 3010 | |||
| 3011 | const int r0 = (tgpig.x*NSG + sgitg)*NR0; | ||
| 3012 | //const int r0 = tgpig.x*NR0; | ||
| 3013 | const int r1 = tgpig.y; | ||
| 3014 | const int im = tgpig.z; | ||
| 3015 | |||
| 3016 | const uint i12 = im%args.ne12; | ||
| 3017 | const uint i13 = im/args.ne12; | ||
| 3018 | |||
| 3019 | //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 3020 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 3021 | |||
| 3022 | //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); | ||
| 3023 | device const float * y = (device const float *) (src1 + offset1); | ||
| 3024 | |||
| 3025 | // pointers to src0 rows | ||
| 3026 | device const block_q_type * ax[NR0]; | ||
| 3027 | FOR_UNROLL (int row = 0; row < NR0; ++row) { | ||
| 3028 | const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 3029 | |||
| 3030 | ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); | ||
| 3031 | } | ||
| 3032 | |||
| 3033 | float sumf[NR0] = {0.f}; | ||
| 3034 | |||
| 3035 | const short ix = (tiisg/(NW/NQ)); | ||
| 3036 | const short il = (tiisg%(NW/NQ))*8; | ||
| 3037 | |||
| 3038 | //const int ib0 = sgitg*NQ + ix; | ||
| 3039 | const int ib0 = ix; | ||
| 3040 | |||
| 3041 | float yl[16]; // src1 vector cache | ||
| 3042 | |||
| 3043 | //device const float * yb = y + ix*QK4_0 + il; | ||
| 3044 | device const float * yb = y + ib0*QK4_0 + il; | ||
| 3045 | |||
| 3046 | // each thread in a SIMD group deals with half a block. | ||
| 3047 | //for (int ib = ib0; ib < nb; ib += NSG*NQ) { | ||
| 3048 | for (int ib = ib0; ib < nb; ib += NQ) { | ||
| 3049 | float sumy[2] = { 0.f, 0.f }; | ||
| 3050 | |||
| 3051 | FOR_UNROLL (short i = 0; i < 8; i += 2) { | ||
| 3052 | sumy[0] += yb[i + 0] + yb[i + 1]; | ||
| 3053 | yl[i + 0] = yb[i + 0]; | ||
| 3054 | yl[i + 1] = yb[i + 1]/256.f; | ||
| 3055 | |||
| 3056 | sumy[1] += yb[i + 16] + yb[i + 17]; | ||
| 3057 | yl[i + 8] = yb[i + 16]/16.f; | ||
| 3058 | yl[i + 9] = yb[i + 17]/4096.f; | ||
| 3059 | } | ||
| 3060 | |||
| 3061 | FOR_UNROLL (short row = 0; row < NR0; row++) { | ||
| 3062 | sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); | ||
| 3063 | } | ||
| 3064 | |||
| 3065 | yb += QK4_0 * 16; | ||
| 3066 | //yb += NSG*NQ*QK4_0; | ||
| 3067 | } | ||
| 3068 | |||
| 3069 | device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; | ||
| 3070 | |||
| 3071 | //helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); | ||
| 3072 | |||
| 3073 | for (int row = 0; row < NR0; ++row) { | ||
| 3074 | const float tot = simd_sum(sumf[row]); | ||
| 3075 | |||
| 3076 | if (tiisg == 0 && r0 + row < args.ne01) { | ||
| 3077 | dst_f32[r0 + row] = tot; | ||
| 3078 | } | ||
| 3079 | } | ||
| 3080 | } | ||
| 3081 | |||
| 3082 | kernel void kernel_mul_mv_q4_0_f32( | ||
| 3083 | constant ggml_metal_kargs_mul_mv & args, | ||
| 3084 | device const char * src0, | ||
| 3085 | device const char * src1, | ||
| 3086 | device char * dst, | ||
| 3087 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 3088 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 3089 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 3090 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 3091 | mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||
| 3092 | } | ||
| 3093 | |||
| 3094 | kernel void kernel_mul_mv_q4_1_f32( | ||
| 3095 | constant ggml_metal_kargs_mul_mv & args, | ||
| 3096 | device const char * src0, | ||
| 3097 | device const char * src1, | ||
| 3098 | device char * dst, | ||
| 3099 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 3100 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 3101 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 3102 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 3103 | mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||
| 3104 | } | ||
| 3105 | |||
| 3106 | kernel void kernel_mul_mv_q5_0_f32( | ||
| 3107 | constant ggml_metal_kargs_mul_mv & args, | ||
| 3108 | device const char * src0, | ||
| 3109 | device const char * src1, | ||
| 3110 | device char * dst, | ||
| 3111 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 3112 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 3113 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 3114 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 3115 | mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||
| 3116 | } | ||
| 3117 | |||
| 3118 | kernel void kernel_mul_mv_q5_1_f32( | ||
| 3119 | constant ggml_metal_kargs_mul_mv & args, | ||
| 3120 | device const char * src0, | ||
| 3121 | device const char * src1, | ||
| 3122 | device char * dst, | ||
| 3123 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 3124 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 3125 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 3126 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 3127 | mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||
| 3128 | } | ||
| 3129 | |||
| 3130 | template<short NR0, typename args_t> | ||
| 3131 | void kernel_mul_mv_q8_0_f32_impl( | ||
| 3132 | args_t args, | ||
| 3133 | device const char * src0, | ||
| 3134 | device const char * src1, | ||
| 3135 | device char * dst, | ||
| 3136 | threadgroup char * shmem, | ||
| 3137 | uint3 tgpig, | ||
| 3138 | ushort tiisg, | ||
| 3139 | ushort sgitg) { | ||
| 3140 | const short NSG = FC_mul_mv_nsg; | ||
| 3141 | |||
| 3142 | constexpr short NW = N_SIMDWIDTH; | ||
| 3143 | constexpr short NQ = 8; | ||
| 3144 | |||
| 3145 | const int nb = args.ne00/QK8_0; | ||
| 3146 | |||
| 3147 | const int r0 = tgpig.x*NR0; | ||
| 3148 | const int r1 = tgpig.y; | ||
| 3149 | const int im = tgpig.z; | ||
| 3150 | |||
| 3151 | const uint i12 = im%args.ne12; | ||
| 3152 | const uint i13 = im/args.ne12; | ||
| 3153 | |||
| 3154 | //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 3155 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 3156 | |||
| 3157 | //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); | ||
| 3158 | device const float * y = (device const float *) (src1 + offset1); | ||
| 3159 | |||
| 3160 | // pointers to src0 rows | ||
| 3161 | device const block_q8_0 * ax[NR0]; | ||
| 3162 | FOR_UNROLL (short row = 0; row < NR0; ++row) { | ||
| 3163 | const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 3164 | |||
| 3165 | ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); | ||
| 3166 | } | ||
| 3167 | |||
| 3168 | float sumf[NR0] = { 0.f }; | ||
| 3169 | |||
| 3170 | const short ix = tiisg/(NW/NQ); | ||
| 3171 | const short il = tiisg%(NW/NQ); | ||
| 3172 | |||
| 3173 | const int ib0 = sgitg*NQ + ix; | ||
| 3174 | |||
| 3175 | float yl[NQ]; | ||
| 3176 | |||
| 3177 | device const float * yb = y + ib0*QK8_0 + il*NQ; | ||
| 3178 | |||
| 3179 | // each thread in a SIMD group deals with NQ quants at a time | ||
| 3180 | for (int ib = ib0; ib < nb; ib += NSG*NQ) { | ||
| 3181 | for (short i = 0; i < NQ; ++i) { | ||
| 3182 | yl[i] = yb[i]; | ||
| 3183 | } | ||
| 3184 | |||
| 3185 | for (short row = 0; row < NR0; row++) { | ||
| 3186 | device const int8_t * qs = ax[row][ib].qs + il*NQ; | ||
| 3187 | |||
| 3188 | float sumq = 0.f; | ||
| 3189 | FOR_UNROLL (short i = 0; i < NQ; ++i) { | ||
| 3190 | sumq += qs[i] * yl[i]; | ||
| 3191 | } | ||
| 3192 | |||
| 3193 | sumf[row] += sumq*ax[row][ib].d; | ||
| 3194 | } | ||
| 3195 | |||
| 3196 | yb += NSG*NQ*QK8_0; | ||
| 3197 | } | ||
| 3198 | |||
| 3199 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 3200 | |||
| 3201 | helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); | ||
| 3202 | } | ||
| 3203 | |||
| 3204 | [[host_name("kernel_mul_mv_q8_0_f32")]] | ||
| 3205 | kernel void kernel_mul_mv_q8_0_f32( | ||
| 3206 | constant ggml_metal_kargs_mul_mv & args, | ||
| 3207 | device const char * src0, | ||
| 3208 | device const char * src1, | ||
| 3209 | device char * dst, | ||
| 3210 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 3211 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 3212 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 3213 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 3214 | kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||
| 3215 | } | ||
| 3216 | |||
| 3217 | // mat-vec kernel processing in chunks of float4 | ||
| 3218 | // chpb - chunks per quantization block | ||
| 3219 | template<short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) > | ||
| 3220 | void kernel_mul_mv_ext_q4_f32_impl( | ||
| 3221 | constant ggml_metal_kargs_mul_mv_ext & args, | ||
| 3222 | device const char * src0, | ||
| 3223 | device const char * src1, | ||
| 3224 | device char * dst, | ||
| 3225 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 3226 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 3227 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 3228 | const short NSG = FC_mul_mv_nsg; | ||
| 3229 | const short nxpsg = FC_mul_mv_nxpsg; | ||
| 3230 | |||
| 3231 | const short chpt = 4; // chunks per thread | ||
| 3232 | |||
| 3233 | //const short nxpsg = (32); | ||
| 3234 | const short nypsg = (32/nxpsg); | ||
| 3235 | |||
| 3236 | const short tx = tiisg%nxpsg; | ||
| 3237 | const short ty = tiisg/nxpsg; | ||
| 3238 | |||
| 3239 | const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty; | ||
| 3240 | const int i11 = tgpig.y*r1ptg; | ||
| 3241 | const int i1m = tgpig.z; | ||
| 3242 | |||
| 3243 | const int i12 = i1m%args.ne12; | ||
| 3244 | const int i13 = i1m/args.ne12; | ||
| 3245 | |||
| 3246 | const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 3247 | const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 3248 | |||
| 3249 | device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; | ||
| 3250 | |||
| 3251 | device const float4 * y4[r1ptg]; | ||
| 3252 | |||
| 3253 | for (int ir1 = 0; ir1 < r1ptg; ++ir1) { | ||
| 3254 | y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1; | ||
| 3255 | } | ||
| 3256 | |||
| 3257 | float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f }; | ||
| 3258 | |||
| 3259 | short cch = tx%chpb; // current chunk index | ||
| 3260 | |||
| 3261 | for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) { | ||
| 3262 | float4 lx[chpt]; | ||
| 3263 | |||
| 3264 | #pragma unroll(chpt) | ||
| 3265 | for (short ch = 0; ch < chpt; ++ch) { | ||
| 3266 | deq_t4(xq, cch, lx[ch]); | ||
| 3267 | |||
| 3268 | cch += nxpsg; | ||
| 3269 | if (cch >= chpb) { | ||
| 3270 | xq += cch/chpb; | ||
| 3271 | cch %= chpb; | ||
| 3272 | } | ||
| 3273 | } | ||
| 3274 | |||
| 3275 | #pragma unroll(chpt) | ||
| 3276 | for (short ch = 0; ch < chpt; ++ch) { | ||
| 3277 | #pragma unroll(r1ptg) | ||
| 3278 | for (short ir1 = 0; ir1 < r1ptg; ++ir1) { | ||
| 3279 | sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]); | ||
| 3280 | } | ||
| 3281 | } | ||
| 3282 | |||
| 3283 | #pragma unroll(r1ptg) | ||
| 3284 | for (short ir1 = 0; ir1 < r1ptg; ++ir1) { | ||
| 3285 | y4[ir1] += chpt*nxpsg; | ||
| 3286 | } | ||
| 3287 | } | ||
| 3288 | |||
| 3289 | // reduce only the threads in each row | ||
| 3290 | for (short ir1 = 0; ir1 < r1ptg; ++ir1) { | ||
| 3291 | if (nxpsg >= 32) { | ||
| 3292 | sumf[ir1] += simd_shuffle_down(sumf[ir1], 16); | ||
| 3293 | } | ||
| 3294 | if (nxpsg >= 16) { | ||
| 3295 | sumf[ir1] += simd_shuffle_down(sumf[ir1], 8); | ||
| 3296 | } | ||
| 3297 | if (nxpsg >= 8) { | ||
| 3298 | sumf[ir1] += simd_shuffle_down(sumf[ir1], 4); | ||
| 3299 | } | ||
| 3300 | if (nxpsg >= 4) { | ||
| 3301 | sumf[ir1] += simd_shuffle_down(sumf[ir1], 2); | ||
| 3302 | } | ||
| 3303 | if (nxpsg >= 2) { | ||
| 3304 | sumf[ir1] += simd_shuffle_down(sumf[ir1], 1); | ||
| 3305 | } | ||
| 3306 | |||
| 3307 | //sumf[ir1] = simd_sum(sumf[ir1]); | ||
| 3308 | } | ||
| 3309 | |||
| 3310 | if (tx == 0) { | ||
| 3311 | for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) { | ||
| 3312 | device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0; | ||
| 3313 | |||
| 3314 | if (i01 < args.ne01) { | ||
| 3315 | dst_f32[i01] = sumf[ir1]; | ||
| 3316 | } | ||
| 3317 | } | ||
| 3318 | } | ||
| 3319 | } | ||
| 3320 | |||
| 3321 | // mat-vec kernel processing in chunks of float4x4 | ||
| 3322 | template<short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) > | ||
| 3323 | void kernel_mul_mv_ext_q4x4_f32_impl( | ||
| 3324 | constant ggml_metal_kargs_mul_mv_ext & args, | ||
| 3325 | device const char * src0, | ||
| 3326 | device const char * src1, | ||
| 3327 | device char * dst, | ||
| 3328 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 3329 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 3330 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 3331 | const short NSG = FC_mul_mv_nsg; | ||
| 3332 | const short nxpsg = FC_mul_mv_nxpsg; | ||
| 3333 | |||
| 3334 | const short chpt = 1; | ||
| 3335 | |||
| 3336 | //const short nxpsg = (32); | ||
| 3337 | const short nypsg = (32/nxpsg); | ||
| 3338 | |||
| 3339 | const short tx = tiisg%nxpsg; | ||
| 3340 | const short ty = tiisg/nxpsg; | ||
| 3341 | |||
| 3342 | const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty; | ||
| 3343 | const int i11 = tgpig.y*r1ptg; | ||
| 3344 | const int i1m = tgpig.z; | ||
| 3345 | |||
| 3346 | const int i12 = i1m%args.ne12; | ||
| 3347 | const int i13 = i1m/args.ne12; | ||
| 3348 | |||
| 3349 | const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 3350 | const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 3351 | |||
| 3352 | device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; | ||
| 3353 | |||
| 3354 | device const float4x4 * y4x4[r1ptg]; | ||
| 3355 | |||
| 3356 | for (int ir1 = 0; ir1 < r1ptg; ++ir1) { | ||
| 3357 | y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1; | ||
| 3358 | } | ||
| 3359 | |||
| 3360 | float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f }; | ||
| 3361 | |||
| 3362 | short cch = tx%chpb; | ||
| 3363 | |||
| 3364 | for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) { | ||
| 3365 | float4x4 lx[chpt]; | ||
| 3366 | |||
| 3367 | #pragma unroll(chpt) | ||
| 3368 | for (short ch = 0; ch < chpt; ++ch) { | ||
| 3369 | deq_t4x4(xq, cch, lx[ch]); | ||
| 3370 | |||
| 3371 | cch += nxpsg; | ||
| 3372 | if (cch >= chpb) { | ||
| 3373 | xq += cch/chpb; | ||
| 3374 | cch %= chpb; | ||
| 3375 | } | ||
| 3376 | } | ||
| 3377 | |||
| 3378 | #pragma unroll(chpt) | ||
| 3379 | for (short ch = 0; ch < chpt; ++ch) { | ||
| 3380 | #pragma unroll(r1ptg) | ||
| 3381 | for (short ir1 = 0; ir1 < r1ptg; ++ir1) { | ||
| 3382 | sumf[ir1] += | ||
| 3383 | dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) + | ||
| 3384 | dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) + | ||
| 3385 | dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) + | ||
| 3386 | dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]); | ||
| 3387 | |||
| 3388 | } | ||
| 3389 | } | ||
| 3390 | |||
| 3391 | #pragma unroll(r1ptg) | ||
| 3392 | for (short ir1 = 0; ir1 < r1ptg; ++ir1) { | ||
| 3393 | y4x4[ir1] += chpt*nxpsg; | ||
| 3394 | } | ||
| 3395 | } | ||
| 3396 | |||
| 3397 | for (short ir1 = 0; ir1 < r1ptg; ++ir1) { | ||
| 3398 | if (nxpsg >= 32) { | ||
| 3399 | sumf[ir1] += simd_shuffle_down(sumf[ir1], 16); | ||
| 3400 | } | ||
| 3401 | if (nxpsg >= 16) { | ||
| 3402 | sumf[ir1] += simd_shuffle_down(sumf[ir1], 8); | ||
| 3403 | } | ||
| 3404 | if (nxpsg >= 8) { | ||
| 3405 | sumf[ir1] += simd_shuffle_down(sumf[ir1], 4); | ||
| 3406 | } | ||
| 3407 | if (nxpsg >= 4) { | ||
| 3408 | sumf[ir1] += simd_shuffle_down(sumf[ir1], 2); | ||
| 3409 | } | ||
| 3410 | if (nxpsg >= 2) { | ||
| 3411 | sumf[ir1] += simd_shuffle_down(sumf[ir1], 1); | ||
| 3412 | } | ||
| 3413 | |||
| 3414 | //sumf[ir1] = simd_sum(sumf[ir1]); | ||
| 3415 | } | ||
| 3416 | |||
| 3417 | if (tx == 0) { | ||
| 3418 | for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) { | ||
| 3419 | device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0; | ||
| 3420 | |||
| 3421 | if (i01 < args.ne01) { | ||
| 3422 | dst_f32[i01] = sumf[ir1]; | ||
| 3423 | } | ||
| 3424 | } | ||
| 3425 | } | ||
| 3426 | } | ||
| 3427 | |||
| 3428 | // dispatchers needed for compile-time nxpsg | ||
| 3429 | // epb - elements per quantization block | ||
| 3430 | template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)> | ||
| 3431 | kernel void kernel_mul_mv_ext_q4_f32_disp( | ||
| 3432 | constant ggml_metal_kargs_mul_mv_ext & args, | ||
| 3433 | device const char * src0, | ||
| 3434 | device const char * src1, | ||
| 3435 | device char * dst, | ||
| 3436 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 3437 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 3438 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 3439 | kernel_mul_mv_ext_q4_f32_impl<r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); | ||
| 3440 | } | ||
| 3441 | |||
| 3442 | template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)> | ||
| 3443 | kernel void kernel_mul_mv_ext_q4x4_f32_disp( | ||
| 3444 | constant ggml_metal_kargs_mul_mv_ext & args, | ||
| 3445 | device const char * src0, | ||
| 3446 | device const char * src1, | ||
| 3447 | device char * dst, | ||
| 3448 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 3449 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 3450 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 3451 | kernel_mul_mv_ext_q4x4_f32_impl<r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); | ||
| 3452 | } | ||
| 3453 | |||
| 3454 | typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t; | ||
| 3455 | typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t; | ||
| 3456 | |||
| 3457 | template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4, 4, dequantize_f32_t4>; | ||
| 3458 | template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4, 4, dequantize_f32_t4>; | ||
| 3459 | template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4, 4, dequantize_f32_t4>; | ||
| 3460 | template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4, 4, dequantize_f32_t4>; | ||
| 3461 | |||
| 3462 | template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>; | ||
| 3463 | template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>; | ||
| 3464 | template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>; | ||
| 3465 | template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>; | ||
| 3466 | |||
| 3467 | template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>; | ||
| 3468 | template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>; | ||
| 3469 | template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>; | ||
| 3470 | template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0, 32, dequantize_q4_0_t4>; | ||
| 3471 | |||
| 3472 | template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1, 32, dequantize_q4_1_t4>; | ||
| 3473 | template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1, 32, dequantize_q4_1_t4>; | ||
| 3474 | template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1, 32, dequantize_q4_1_t4>; | ||
| 3475 | template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1, 32, dequantize_q4_1_t4>; | ||
| 3476 | |||
| 3477 | template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0, 32, dequantize_q5_0_t4>; | ||
| 3478 | template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0, 32, dequantize_q5_0_t4>; | ||
| 3479 | template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0, 32, dequantize_q5_0_t4>; | ||
| 3480 | template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0, 32, dequantize_q5_0_t4>; | ||
| 3481 | |||
| 3482 | template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1, 32, dequantize_q5_1_t4>; | ||
| 3483 | template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1, 32, dequantize_q5_1_t4>; | ||
| 3484 | template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1, 32, dequantize_q5_1_t4>; | ||
| 3485 | template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1, 32, dequantize_q5_1_t4>; | ||
| 3486 | |||
| 3487 | template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>; | ||
| 3488 | template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0, 32, dequantize_q8_0_t4>; | ||
| 3489 | template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>; | ||
| 3490 | template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>; | ||
| 3491 | |||
| 3492 | template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_mxfp4, 32, dequantize_mxfp4_t4>; | ||
| 3493 | template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_mxfp4, 32, dequantize_mxfp4_t4>; | ||
| 3494 | template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4, 32, dequantize_mxfp4_t4>; | ||
| 3495 | template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4, 32, dequantize_mxfp4_t4>; | ||
| 3496 | |||
| 3497 | template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>; | ||
| 3498 | template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>; | ||
| 3499 | template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>; | ||
| 3500 | template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>; | ||
| 3501 | |||
| 3502 | template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>; | ||
| 3503 | template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>; | ||
| 3504 | template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>; | ||
| 3505 | template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>; | ||
| 3506 | |||
| 3507 | template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>; | ||
| 3508 | template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>; | ||
| 3509 | template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>; | ||
| 3510 | template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>; | ||
| 3511 | |||
| 3512 | template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>; | ||
| 3513 | template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>; | ||
| 3514 | template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>; | ||
| 3515 | template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>; | ||
| 3516 | |||
| 3517 | template<typename T0, typename T1, short NR0, typename args_t> | ||
| 3518 | void kernel_mul_mv_t_t_impl( | ||
| 3519 | args_t args, | ||
| 3520 | device const char * src0, | ||
| 3521 | device const char * src1, | ||
| 3522 | device char * dst, | ||
| 3523 | threadgroup char * shmem, | ||
| 3524 | uint3 tgpig, | ||
| 3525 | ushort tiisg, | ||
| 3526 | ushort sgitg) { | ||
| 3527 | const short NSG = FC_mul_mv_nsg; | ||
| 3528 | |||
| 3529 | constexpr short NW = N_SIMDWIDTH; | ||
| 3530 | constexpr short NB = 32; | ||
| 3531 | constexpr short NF = 8; | ||
| 3532 | |||
| 3533 | const int nb = args.ne00/NB; | ||
| 3534 | |||
| 3535 | const int r0 = tgpig.x*NR0; | ||
| 3536 | const int r1 = tgpig.y; | ||
| 3537 | const int im = tgpig.z; | ||
| 3538 | |||
| 3539 | const uint i12 = im%args.ne12; | ||
| 3540 | const uint i13 = im/args.ne12; | ||
| 3541 | |||
| 3542 | //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 3543 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 3544 | |||
| 3545 | //device const T0 * x = (device const T0 *) (src0 + offset0); | ||
| 3546 | device const T1 * y = (device const T1 *) (src1 + offset1); | ||
| 3547 | |||
| 3548 | // pointers to src0 rows | ||
| 3549 | device const T0 * ax [NR0]; | ||
| 3550 | FOR_UNROLL (short row = 0; row < NR0; ++row) { | ||
| 3551 | const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 3552 | |||
| 3553 | ax[row] = (device const T0 *) ((device char *) src0 + offset0); | ||
| 3554 | } | ||
| 3555 | |||
| 3556 | float sumf[NR0] = { 0.f }; | ||
| 3557 | |||
| 3558 | const short ix = tiisg/(NW/NF); | ||
| 3559 | const short il = tiisg%(NW/NF); | ||
| 3560 | |||
| 3561 | const int ib0 = sgitg*NF + ix; | ||
| 3562 | |||
| 3563 | T1 yl[NF]; | ||
| 3564 | |||
| 3565 | device const T1 * yb = y + (ib0*NB + il*NF); | ||
| 3566 | |||
| 3567 | for (int ib = ib0; ib < nb; ib += NSG*NF) { | ||
| 3568 | for (short i = 0; i < NF; ++i) { | ||
| 3569 | yl[i] = yb[i]; | ||
| 3570 | } | ||
| 3571 | |||
| 3572 | for (short row = 0; row < NR0; row++) { | ||
| 3573 | device const T0 * xb = ax[row] + (ib*NB + il*NF); | ||
| 3574 | |||
| 3575 | float sumq = 0.f; | ||
| 3576 | FOR_UNROLL (short i = 0; i < NF; ++i) { | ||
| 3577 | sumq += xb[i] * yl[i]; | ||
| 3578 | } | ||
| 3579 | |||
| 3580 | sumf[row] += sumq; | ||
| 3581 | } | ||
| 3582 | |||
| 3583 | yb += NSG*NF*NW; | ||
| 3584 | } | ||
| 3585 | |||
| 3586 | for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) { | ||
| 3587 | for (short row = 0; row < NR0; row++) { | ||
| 3588 | sumf[row] += ax[row][i] * y[i]; | ||
| 3589 | } | ||
| 3590 | } | ||
| 3591 | |||
| 3592 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 3593 | |||
| 3594 | helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); | ||
| 3595 | } | ||
| 3596 | |||
| 3597 | template<typename T0, typename T1, typename args_t> | ||
| 3598 | void kernel_mul_mv_t_t_disp( | ||
| 3599 | args_t args, | ||
| 3600 | device const char * src0, | ||
| 3601 | device const char * src1, | ||
| 3602 | device char * dst, | ||
| 3603 | threadgroup char * shmem, | ||
| 3604 | uint3 tgpig, | ||
| 3605 | ushort tiisg, | ||
| 3606 | ushort sgitg) { | ||
| 3607 | switch (args.nr0) { | ||
| 3608 | //case 1: kernel_mul_mv_t_t_impl<T0, T1, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; | ||
| 3609 | case 2: kernel_mul_mv_t_t_impl<T0, T1, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; | ||
| 3610 | //case 3: kernel_mul_mv_t_t_impl<T0, T1, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; | ||
| 3611 | //case 4: kernel_mul_mv_t_t_impl<T0, T1, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; | ||
| 3612 | } | ||
| 3613 | } | ||
| 3614 | |||
| 3615 | template<typename T0, typename T1> | ||
| 3616 | kernel void kernel_mul_mv_t_t( | ||
| 3617 | constant ggml_metal_kargs_mul_mv & args, | ||
| 3618 | device const char * src0, | ||
| 3619 | device const char * src1, | ||
| 3620 | device char * dst, | ||
| 3621 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 3622 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 3623 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 3624 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 3625 | kernel_mul_mv_t_t_disp<T0, T1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||
| 3626 | } | ||
| 3627 | |||
| 3628 | typedef decltype(kernel_mul_mv_t_t<half, half>) mul_mv_t_t; | ||
| 3629 | |||
| 3630 | template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float>; | ||
| 3631 | template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float>; | ||
| 3632 | template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half>; | ||
| 3633 | #if defined(GGML_METAL_HAS_BF16) | ||
| 3634 | template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float>; | ||
| 3635 | template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat>; | ||
| 3636 | #endif | ||
| 3637 | |||
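For reference, the `kernel_mul_mv_t_t*` family above computes a matrix-vector product per `(r1, im)` pair: each threadgroup accumulates dot products of `NR0` rows of `src0` against one row of `src1`, and `helper_mv_reduce_and_write` folds the per-simdgroup partial sums into `dst`. A minimal CPU sketch of the f32×f32 result (not part of this commit; it ignores the `ne12`/`ne13` broadcast dims and the `r2`/`r3` ratios):

```cpp
#include <cstdint>

// CPU reference for the f32 x f32 case of the mul_mv kernels (illustrative sketch only).
// dst[r1*ne01 + r0] = dot(src0 row r0, src1 row r1); ne00 is the shared (reduction) dimension.
static void mul_mv_ref_f32(
        const float * src0, int64_t ne00, int64_t ne01,  // src0: ne01 rows of length ne00
        const float * src1, int64_t ne11,                // src1: ne11 rows of length ne00
        float * dst) {                                   // dst : ne11 x ne01
    for (int64_t r1 = 0; r1 < ne11; ++r1) {
        for (int64_t r0 = 0; r0 < ne01; ++r0) {
            float sum = 0.0f;
            for (int64_t i = 0; i < ne00; ++i) {
                sum += src0[r0*ne00 + i] * src1[r1*ne00 + i];
            }
            dst[r1*ne01 + r0] = sum;
        }
    }
}
```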
| 3638 | template<typename T0, typename T04, typename T1, typename T14, short NR0, typename args_t> | ||
| 3639 | void kernel_mul_mv_t_t_4_impl( | ||
| 3640 | args_t args, | ||
| 3641 | device const char * src0, | ||
| 3642 | device const char * src1, | ||
| 3643 | device char * dst, | ||
| 3644 | threadgroup char * shmem, | ||
| 3645 | uint3 tgpig, | ||
| 3646 | ushort tiisg, | ||
| 3647 | ushort sgitg) { | ||
| 3648 | const short NSG = FC_mul_mv_nsg; | ||
| 3649 | |||
| 3650 | constexpr short NW = N_SIMDWIDTH; | ||
| 3651 | constexpr short NB = 32; | ||
| 3652 | constexpr short NF = 16; | ||
| 3653 | constexpr short NF4 = NF/4; | ||
| 3654 | |||
| 3655 | const int nb = args.ne00/NB; | ||
| 3656 | |||
| 3657 | const int r0 = tgpig.x*NR0; | ||
| 3658 | const int r1 = tgpig.y; | ||
| 3659 | const int im = tgpig.z; | ||
| 3660 | |||
| 3661 | const uint i12 = im%args.ne12; | ||
| 3662 | const uint i13 = im/args.ne12; | ||
| 3663 | |||
| 3664 | //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 3665 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 3666 | |||
| 3667 | device const T1 * y = (device const T1 *) (src1 + offset1); | ||
| 3668 | device const T14 * y4 = (device const T14 *) (src1 + offset1); | ||
| 3669 | |||
| 3670 | // pointers to src0 rows | ||
| 3671 | device const T0 * ax [NR0]; | ||
| 3672 | device const T04 * ax4[NR0]; | ||
| 3673 | FOR_UNROLL (short row = 0; row < NR0; ++row) { | ||
| 3674 | const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 3675 | |||
| 3676 | ax [row] = (device const T0 *) ((device char *) src0 + offset0); | ||
| 3677 | ax4[row] = (device const T04 *) ((device char *) src0 + offset0); | ||
| 3678 | } | ||
| 3679 | |||
| 3680 | float sumf[NR0] = { 0.f }; | ||
| 3681 | |||
| 3682 | const short ix = tiisg/(NW/NF); | ||
| 3683 | const short il = tiisg%(NW/NF); | ||
| 3684 | |||
| 3685 | const int ib0 = sgitg*NF + ix; | ||
| 3686 | |||
| 3687 | T14 yl4[NF4]; | ||
| 3688 | |||
| 3689 | device const T14 * yb4 = y4 + (ib0*NB + il*NF)/4; | ||
| 3690 | |||
| 3691 | for (int ib = ib0; ib < nb; ib += NSG*NF) { | ||
| 3692 | for (short i = 0; i < NF4; ++i) { | ||
| 3693 | yl4[i] = yb4[i]; | ||
| 3694 | } | ||
| 3695 | |||
| 3696 | for (short row = 0; row < NR0; row++) { | ||
| 3697 | device const T04 * xb4 = ax4[row] + (ib*NB + il*NF)/4; | ||
| 3698 | |||
| 3699 | float sumq = 0.f; | ||
| 3700 | FOR_UNROLL (short i = 0; i < NF4; ++i) { | ||
| 3701 | sumq += dot(float4(xb4[i]), float4(yl4[i])); | ||
| 3702 | } | ||
| 3703 | |||
| 3704 | sumf[row] += sumq; | ||
| 3705 | } | ||
| 3706 | |||
| 3707 | yb4 += NSG*NF*NW/4; | ||
| 3708 | } | ||
| 3709 | |||
| 3710 | for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) { | ||
| 3711 | for (short row = 0; row < NR0; row++) { | ||
| 3712 | sumf[row] += ax[row][i] * y[i]; | ||
| 3713 | } | ||
| 3714 | } | ||
| 3715 | |||
| 3716 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 3717 | |||
| 3718 | helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); | ||
| 3719 | } | ||
| 3720 | |||
| 3721 | template<typename T0, typename T04, typename T1, typename T14, typename args_t> | ||
| 3722 | void kernel_mul_mv_t_t_4_disp( | ||
| 3723 | args_t args, | ||
| 3724 | device const char * src0, | ||
| 3725 | device const char * src1, | ||
| 3726 | device char * dst, | ||
| 3727 | threadgroup char * shmem, | ||
| 3728 | uint3 tgpig, | ||
| 3729 | ushort tiisg, | ||
| 3730 | ushort sgitg) { | ||
| 3731 | switch (args.nr0) { | ||
| 3732 | //case 1: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 1, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; | ||
| 3733 | case 2: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 2, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; | ||
| 3734 | //case 3: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 3, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; | ||
| 3735 | //case 4: kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, 4, args_t>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; | ||
| 3736 | } | ||
| 3737 | } | ||
| 3738 | |||
| 3739 | template<typename T0, typename T04, typename T1, typename T14> | ||
| 3740 | kernel void kernel_mul_mv_t_t_4( | ||
| 3741 | constant ggml_metal_kargs_mul_mv & args, | ||
| 3742 | device const char * src0, | ||
| 3743 | device const char * src1, | ||
| 3744 | device char * dst, | ||
| 3745 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 3746 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 3747 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 3748 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 3749 | kernel_mul_mv_t_t_4_disp<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||
| 3750 | } | ||
| 3751 | |||
| 3752 | typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4>) mul_mv_t_t_4; | ||
| 3753 | |||
| 3754 | template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4>; | ||
| 3755 | template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4>; | ||
| 3756 | template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4>; | ||
| 3757 | #if defined(GGML_METAL_HAS_BF16) | ||
| 3758 | template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4>; | ||
| 3759 | template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4>; | ||
| 3760 | #endif | ||
| 3761 | |||
| 3762 | template<typename T0, typename T1, typename args_t> | ||
| 3763 | void kernel_mul_mv_t_t_short_impl( | ||
| 3764 | args_t args, | ||
| 3765 | device const char * src0, | ||
| 3766 | device const char * src1, | ||
| 3767 | device char * dst, | ||
| 3768 | uint3 tgpig, | ||
| 3769 | ushort tiisg) { | ||
| 3770 | const int r0 = tgpig.x*32 + tiisg; | ||
| 3771 | const int r1 = tgpig.y; | ||
| 3772 | const int im = tgpig.z; | ||
| 3773 | |||
| 3774 | if (r0 >= args.ne01) { | ||
| 3775 | return; | ||
| 3776 | } | ||
| 3777 | |||
| 3778 | const uint i12 = im%args.ne12; | ||
| 3779 | const uint i13 = im/args.ne12; | ||
| 3780 | |||
| 3781 | const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 3782 | |||
| 3783 | device const T0 * x = (device const T0 *) (src0 + offset0); | ||
| 3784 | |||
| 3785 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; | ||
| 3786 | |||
| 3787 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 3788 | |||
| 3789 | device const T1 * y = (device const T1 *) (src1 + offset1); | ||
| 3790 | |||
| 3791 | float res = 0.0f; | ||
| 3792 | |||
| 3793 | for (int i = 0; i < args.ne00; ++i) { | ||
| 3794 | res += (float) x[i] * (float) y[i]; | ||
| 3795 | } | ||
| 3796 | |||
| 3797 | dst_f32[(uint64_t)r1*args.ne0 + r0] = res; | ||
| 3798 | } | ||
| 3799 | |||
| 3800 | template<typename T0, typename T1> | ||
| 3801 | kernel void kernel_mul_mv_t_t_short( | ||
| 3802 | constant ggml_metal_kargs_mul_mv & args, | ||
| 3803 | device const char * src0, | ||
| 3804 | device const char * src1, | ||
| 3805 | device char * dst, | ||
| 3806 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 3807 | ushort tiisg[[thread_index_in_simdgroup]]) { | ||
| 3808 | kernel_mul_mv_t_t_short_impl<T0, T1, constant ggml_metal_kargs_mul_mv &>( | ||
| 3809 | args, | ||
| 3810 | src0, | ||
| 3811 | src1, | ||
| 3812 | dst, | ||
| 3813 | tgpig, | ||
| 3814 | tiisg); | ||
| 3815 | } | ||
| 3816 | |||
| 3817 | typedef decltype(kernel_mul_mv_t_t_short<half, half>) mul_mv_t_t_short_t; | ||
| 3818 | |||
| 3819 | template [[host_name("kernel_mul_mv_f32_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<float, float>; | ||
| 3820 | template [[host_name("kernel_mul_mv_f16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half, float>; | ||
| 3821 | template [[host_name("kernel_mul_mv_f16_f16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<half, half>; | ||
| 3822 | #if defined(GGML_METAL_HAS_BF16) | ||
| 3823 | template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, float>; | ||
| 3824 | template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short<bfloat, bfloat>; | ||
| 3825 | #endif | ||
| 3826 | |||
| 3827 | constant bool FC_rope_is_imrope [[function_constant(FC_ROPE + 0)]]; | ||
| 3828 | |||
| 3829 | static float rope_yarn_ramp(const float low, const float high, const int i0) { | ||
| 3830 | const float y = (i0 / 2 - low) / max(0.001f, high - low); | ||
| 3831 | return 1.0f - min(1.0f, max(0.0f, y)); | ||
| 3832 | } | ||
| 3833 | |||
| 3834 | // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn | ||
| 3835 | // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. | ||
| 3836 | static void rope_yarn( | ||
| 3837 | float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale, | ||
| 3838 | thread float * cos_theta, thread float * sin_theta) { | ||
| 3839 | // Get n-d rotational scaling corrected for extrapolation | ||
| 3840 | float theta_interp = freq_scale * theta_extrap; | ||
| 3841 | float theta = theta_interp; | ||
| 3842 | if (ext_factor != 0.0f) { | ||
| 3843 | float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; | ||
| 3844 | theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; | ||
| 3845 | |||
| 3846 | // Get n-d magnitude scaling corrected for interpolation | ||
| 3847 | mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); | ||
| 3848 | } | ||
| 3849 | *cos_theta = cos(theta) * mscale; | ||
| 3850 | *sin_theta = sin(theta) * mscale; | ||
| 3851 | } | ||
| 3852 | |||
| 3853 | // solving `n_rot = n_ctx_orig / (2pi * base^((2*x)/n_dims))` for x, we get | ||
| 3854 | // `corr_factor(n_rot) = n_dims * log(n_ctx_orig / (n_rot * 2pi)) / (2 * log(base))` | ||
| 3855 | static float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) { | ||
| 3856 | return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base)); | ||
| 3857 | } | ||
| 3858 | |||
| 3859 | static void rope_yarn_corr_dims( | ||
| 3860 | int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2] | ||
| 3861 | ) { | ||
| 3862 | // start and end correction dims | ||
| 3863 | dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))); | ||
| 3864 | dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base))); | ||
| 3865 | } | ||
| 3866 | |||
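The `rope_yarn*` helpers above implement the YaRN angle blend: the interpolated angle `freq_scale * theta` is mixed with the extrapolated angle using the ramp over `corr_dims`, and the magnitude scale gets the `1 + 0.1*log(1/freq_scale)` correction. A plain C++ transcription of the same formulas, for reference only (names chosen to mirror the Metal code):

```cpp
#include <algorithm>
#include <cmath>

// CPU transcription of rope_yarn_ramp / rope_yarn above (illustrative sketch only).
static float yarn_ramp(float low, float high, int i0) {
    const float y = (i0/2 - low) / std::max(0.001f, high - low);
    return 1.0f - std::min(1.0f, std::max(0.0f, y));
}

static void yarn_rotation(
        float theta_extrap, float freq_scale, const float corr_dims[2],
        int i0, float ext_factor, float mscale,
        float & cos_theta, float & sin_theta) {
    const float theta_interp = freq_scale * theta_extrap;
    float theta = theta_interp;
    if (ext_factor != 0.0f) {
        // blend interpolation and extrapolation according to the ramp
        const float ramp_mix = yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
        theta = theta_interp*(1.0f - ramp_mix) + theta_extrap*ramp_mix;

        // magnitude scaling corrected for interpolation
        mscale *= 1.0f + 0.1f * std::log(1.0f / freq_scale);
    }
    cos_theta = std::cos(theta) * mscale;
    sin_theta = std::sin(theta) * mscale;
}
```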
| 3867 | template<typename T> | ||
| 3868 | kernel void kernel_rope_norm( | ||
| 3869 | constant ggml_metal_kargs_rope & args, | ||
| 3870 | device const char * src0, | ||
| 3871 | device const char * src1, | ||
| 3872 | device const char * src2, | ||
| 3873 | device char * dst, | ||
| 3874 | ushort tiitg[[thread_index_in_threadgroup]], | ||
| 3875 | ushort3 tptg [[threads_per_threadgroup]], | ||
| 3876 | uint3 tgpig[[threadgroup_position_in_grid]]) { | ||
| 3877 | const int i3 = tgpig[2]; | ||
| 3878 | const int i2 = tgpig[1]; | ||
| 3879 | const int i1 = tgpig[0]; | ||
| 3880 | |||
| 3881 | float corr_dims[2]; | ||
| 3882 | rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); | ||
| 3883 | |||
| 3884 | device const int32_t * pos = (device const int32_t *) src1; | ||
| 3885 | |||
| 3886 | const float theta_base = (float) pos[i2]; | ||
| 3887 | const float inv_ndims = -1.f/args.n_dims; | ||
| 3888 | |||
| 3889 | float cos_theta; | ||
| 3890 | float sin_theta; | ||
| 3891 | |||
| 3892 | for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { | ||
| 3893 | if (i0 < args.n_dims) { | ||
| 3894 | const int ic = i0/2; | ||
| 3895 | |||
| 3896 | const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); | ||
| 3897 | |||
| 3898 | const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; | ||
| 3899 | |||
| 3900 | rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); | ||
| 3901 | |||
| 3902 | device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); | ||
| 3903 | device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); | ||
| 3904 | |||
| 3905 | const float x0 = src[0]; | ||
| 3906 | const float x1 = src[1]; | ||
| 3907 | |||
| 3908 | dst_data[0] = x0*cos_theta - x1*sin_theta; | ||
| 3909 | dst_data[1] = x0*sin_theta + x1*cos_theta; | ||
| 3910 | } else { | ||
| 3911 | device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); | ||
| 3912 | device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); | ||
| 3913 | |||
| 3914 | dst_data[0] = src[0]; | ||
| 3915 | dst_data[1] = src[1]; | ||
| 3916 | } | ||
| 3917 | } | ||
| 3918 | } | ||
| 3919 | |||
| 3920 | template<typename T> | ||
| 3921 | kernel void kernel_rope_neox( | ||
| 3922 | constant ggml_metal_kargs_rope & args, | ||
| 3923 | device const char * src0, | ||
| 3924 | device const char * src1, | ||
| 3925 | device const char * src2, | ||
| 3926 | device char * dst, | ||
| 3927 | ushort tiitg[[thread_index_in_threadgroup]], | ||
| 3928 | ushort3 tptg [[threads_per_threadgroup]], | ||
| 3929 | uint3 tgpig[[threadgroup_position_in_grid]]) { | ||
| 3930 | const int i3 = tgpig[2]; | ||
| 3931 | const int i2 = tgpig[1]; | ||
| 3932 | const int i1 = tgpig[0]; | ||
| 3933 | |||
| 3934 | float corr_dims[2]; | ||
| 3935 | rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); | ||
| 3936 | |||
| 3937 | device const int32_t * pos = (device const int32_t *) src1; | ||
| 3938 | |||
| 3939 | const float theta_base = (float) pos[i2]; | ||
| 3940 | const float inv_ndims = -1.f/args.n_dims; | ||
| 3941 | |||
| 3942 | float cos_theta; | ||
| 3943 | float sin_theta; | ||
| 3944 | |||
| 3945 | for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { | ||
| 3946 | if (i0 < args.n_dims) { | ||
| 3947 | const int ic = i0/2; | ||
| 3948 | |||
| 3949 | const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); | ||
| 3950 | |||
| 3951 | const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; | ||
| 3952 | |||
| 3953 | rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); | ||
| 3954 | |||
| 3955 | device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); | ||
| 3956 | device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); | ||
| 3957 | |||
| 3958 | const float x0 = src[0]; | ||
| 3959 | const float x1 = src[args.n_dims/2]; | ||
| 3960 | |||
| 3961 | dst_data[0] = x0*cos_theta - x1*sin_theta; | ||
| 3962 | dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; | ||
| 3963 | } else { | ||
| 3964 | device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); | ||
| 3965 | device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); | ||
| 3966 | |||
| 3967 | dst_data[0] = src[0]; | ||
| 3968 | dst_data[1] = src[1]; | ||
| 3969 | } | ||
| 3970 | } | ||
| 3971 | } | ||
| 3972 | |||
| 3973 | template<typename T> | ||
| 3974 | kernel void kernel_rope_multi( | ||
| 3975 | constant ggml_metal_kargs_rope & args, | ||
| 3976 | device const char * src0, | ||
| 3977 | device const char * src1, | ||
| 3978 | device const char * src2, | ||
| 3979 | device char * dst, | ||
| 3980 | ushort tiitg[[thread_index_in_threadgroup]], | ||
| 3981 | ushort3 tptg [[threads_per_threadgroup]], | ||
| 3982 | uint3 tgpig[[threadgroup_position_in_grid]]) { | ||
| 3983 | const int i3 = tgpig[2]; | ||
| 3984 | const int i2 = tgpig[1]; | ||
| 3985 | const int i1 = tgpig[0]; | ||
| 3986 | |||
| 3987 | float corr_dims[2]; | ||
| 3988 | rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); | ||
| 3989 | |||
| 3990 | device const int32_t * pos = (device const int32_t *) src1; | ||
| 3991 | |||
| 3992 | const float inv_ndims = -1.f/args.n_dims; | ||
| 3993 | |||
| 3994 | float cos_theta; | ||
| 3995 | float sin_theta; | ||
| 3996 | |||
| 3997 | for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { | ||
| 3998 | if (i0 < args.n_dims) { | ||
| 3999 | const int ic = i0/2; | ||
| 4000 | |||
| 4001 | // mrope theta calculations | ||
| 4002 | // note: the rest is the same as kernel_rope_neox | ||
| 4003 | const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3; | ||
| 4004 | const int sec_w01 = args.sect_0 + args.sect_1; // end of section 1 | ||
| 4005 | const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2 | ||
| 4006 | const int sector = ic % sect_dims; | ||
| 4007 | |||
| 4008 | float theta_base; | ||
| 4009 | if (FC_rope_is_imrope) { | ||
| 4010 | if (sector % 3 == 1 && sector < 3 * args.sect_1) { // h | ||
| 4011 | theta_base = (float) pos[i2 + args.ne02 * 1]; | ||
| 4012 | } else if (sector % 3 == 2 && sector < 3 * args.sect_2) { // w | ||
| 4013 | theta_base = (float) pos[i2 + args.ne02 * 2]; | ||
| 4014 | } else if (sector % 3 == 0 && sector < 3 * args.sect_0) { // t | ||
| 4015 | theta_base = (float) pos[i2 + args.ne02 * 0]; | ||
| 4016 | } else { // e | ||
| 4017 | theta_base = (float) pos[i2 + args.ne02 * 3]; | ||
| 4018 | } | ||
| 4019 | } else { | ||
| 4020 | if (sector < args.sect_0) { | ||
| 4021 | theta_base = (float) pos[i2]; | ||
| 4022 | } else if (sector < sec_w01) { | ||
| 4023 | theta_base = (float) pos[i2 + args.ne02 * 1]; | ||
| 4024 | } else if (sector < sec_w012) { | ||
| 4025 | theta_base = (float) pos[i2 + args.ne02 * 2]; | ||
| 4026 | } else { | ||
| 4027 | theta_base = (float) pos[i2 + args.ne02 * 3]; | ||
| 4028 | } | ||
| 4029 | } | ||
| 4030 | // end of mrope | ||
| 4031 | |||
| 4032 | const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); | ||
| 4033 | |||
| 4034 | const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; | ||
| 4035 | |||
| 4036 | rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); | ||
| 4037 | |||
| 4038 | device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); | ||
| 4039 | device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); | ||
| 4040 | |||
| 4041 | const float x0 = src[0]; | ||
| 4042 | const float x1 = src[args.n_dims/2]; | ||
| 4043 | |||
| 4044 | dst_data[0] = x0*cos_theta - x1*sin_theta; | ||
| 4045 | dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; | ||
| 4046 | } else { | ||
| 4047 | device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); | ||
| 4048 | device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); | ||
| 4049 | |||
| 4050 | dst_data[0] = src[0]; | ||
| 4051 | dst_data[1] = src[1]; | ||
| 4052 | } | ||
| 4053 | } | ||
| 4054 | } | ||
| 4055 | |||
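In `kernel_rope_multi`, the rotary dimensions are split into up to four contiguous sections and each section reads its position from a different channel of `pos` (stride `ne02`). A hypothetical helper sketching the non-imrope mapping used above (channel labels follow the t/h/w/e comments in the imrope branch):

```cpp
// Hypothetical helper (not from this commit): which channel of `pos` a rotary
// sector reads from in the non-imrope branch (0 = t, 1 = h, 2 = w, 3 = e).
static int mrope_pos_channel(int sector, int sect_0, int sect_1, int sect_2) {
    if (sector < sect_0)                   return 0;
    if (sector < sect_0 + sect_1)          return 1;
    if (sector < sect_0 + sect_1 + sect_2) return 2;
    return 3;
}
```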
| 4056 | template<typename T> | ||
| 4057 | kernel void kernel_rope_vision( | ||
| 4058 | constant ggml_metal_kargs_rope & args, | ||
| 4059 | device const char * src0, | ||
| 4060 | device const char * src1, | ||
| 4061 | device const char * src2, | ||
| 4062 | device char * dst, | ||
| 4063 | ushort tiitg[[thread_index_in_threadgroup]], | ||
| 4064 | ushort3 tptg [[threads_per_threadgroup]], | ||
| 4065 | uint3 tgpig[[threadgroup_position_in_grid]]) { | ||
| 4066 | const int i3 = tgpig[2]; | ||
| 4067 | const int i2 = tgpig[1]; | ||
| 4068 | const int i1 = tgpig[0]; | ||
| 4069 | |||
| 4070 | float corr_dims[2]; | ||
| 4071 | rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); | ||
| 4072 | |||
| 4073 | device const int32_t * pos = (device const int32_t *) src1; | ||
| 4074 | |||
| 4075 | const float inv_ndims = -1.f/args.n_dims; | ||
| 4076 | |||
| 4077 | float cos_theta; | ||
| 4078 | float sin_theta; | ||
| 4079 | |||
| 4080 | for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { | ||
| 4081 | if (i0 < 2*args.n_dims) { // different from kernel_rope_multi | ||
| 4082 | const int ic = i0/2; | ||
| 4083 | |||
| 4084 | // mrope theta calculations (only supports 2 dimensions) | ||
| 4085 | const int sect_dims = args.sect_0 + args.sect_1; | ||
| 4086 | const int sector = ic % sect_dims; | ||
| 4087 | |||
| 4088 | float p; | ||
| 4089 | float theta_base; | ||
| 4090 | if (sector < args.sect_1) { | ||
| 4091 | p = (float) sector; | ||
| 4092 | theta_base = (float) pos[i2]; | ||
| 4093 | } else { | ||
| 4094 | p = (float) sector - args.sect_0; | ||
| 4095 | theta_base = (float) pos[i2 + args.ne02]; | ||
| 4096 | } | ||
| 4097 | |||
| 4098 | const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p); | ||
| 4099 | // end of mrope | ||
| 4100 | |||
| 4101 | const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f; | ||
| 4102 | |||
| 4103 | rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); | ||
| 4104 | |||
| 4105 | device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); | ||
| 4106 | device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); | ||
| 4107 | |||
| 4108 | const float x0 = src[0]; | ||
| 4109 | const float x1 = src[args.n_dims]; // different from kernel_rope_multi | ||
| 4110 | |||
| 4111 | dst_data[0] = x0*cos_theta - x1*sin_theta; | ||
| 4112 | dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi | ||
| 4113 | } else { | ||
| 4114 | device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); | ||
| 4115 | device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); | ||
| 4116 | |||
| 4117 | dst_data[0] = src[0]; | ||
| 4118 | dst_data[1] = src[1]; | ||
| 4119 | } | ||
| 4120 | } | ||
| 4121 | } | ||
| 4122 | |||
| 4123 | typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t; | ||
| 4124 | typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t; | ||
| 4125 | typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t; | ||
| 4126 | typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t; | ||
| 4127 | |||
| 4128 | template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>; | ||
| 4129 | template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>; | ||
| 4130 | |||
| 4131 | template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>; | ||
| 4132 | template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>; | ||
| 4133 | |||
| 4134 | template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>; | ||
| 4135 | template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>; | ||
| 4136 | |||
| 4137 | template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>; | ||
| 4138 | template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>; | ||
| 4139 | |||
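The only difference between the `norm` and `neox` rope variants above is which two elements form a rotated pair: `norm` rotates adjacent elements `(i0, i0+1)`, while `neox` rotates `(ic, ic + n_dims/2)` with `ic = i0/2`. A CPU sketch of both layouts (illustrative only; it skips the YaRN correction and the optional `freq_factors`):

```cpp
#include <cmath>

// CPU sketch of the two RoPE pairings used above (hypothetical helper, fixed base angle schedule).
// "norm" rotates adjacent pairs (i0, i0+1); "neox" rotates split-half pairs (ic, ic + n_dims/2).
static void rope_row_ref(float * row, int n_dims, float theta_base, float freq_base, bool neox) {
    for (int i0 = 0; i0 < n_dims; i0 += 2) {
        const int   ic    = i0/2;
        const float theta = theta_base * std::pow(freq_base, -float(i0)/n_dims);
        const float c     = std::cos(theta);
        const float s     = std::sin(theta);

        const int ia = neox ? ic            : i0;
        const int ib = neox ? ic + n_dims/2 : i0 + 1;

        const float x0 = row[ia];
        const float x1 = row[ib];

        row[ia] = x0*c - x1*s;
        row[ib] = x0*s + x1*c;
    }
}
```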
| 4140 | typedef void (im2col_t)( | ||
| 4141 | constant ggml_metal_kargs_im2col & args, | ||
| 4142 | device const float * x, | ||
| 4143 | device char * dst, | ||
| 4144 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4145 | uint3 tgpg[[threadgroups_per_grid]], | ||
| 4146 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 4147 | uint3 ntg[[threads_per_threadgroup]]); | ||
| 4148 | |||
| 4149 | template <typename T> | ||
| 4150 | kernel void kernel_im2col( | ||
| 4151 | constant ggml_metal_kargs_im2col & args, | ||
| 4152 | device const float * x, | ||
| 4153 | device char * dst, | ||
| 4154 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4155 | uint3 tgpg[[threadgroups_per_grid]], | ||
| 4156 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 4157 | uint3 ntg[[threads_per_threadgroup]]) { | ||
| 4158 | // const int64_t IC = tgpg[0]; | ||
| 4159 | const int64_t OH = tgpg[1]; | ||
| 4160 | const int64_t OW = tgpg[2]; | ||
| 4161 | |||
| 4162 | const int64_t KH = ntg[1]; | ||
| 4163 | const int64_t KW = ntg[2]; | ||
| 4164 | |||
| 4165 | int64_t in = tpitg[0]; | ||
| 4166 | const int64_t ikh = tpitg[1]; | ||
| 4167 | const int64_t ikw = tpitg[2]; | ||
| 4168 | |||
| 4169 | const int64_t iic = tgpig[0]; | ||
| 4170 | const int64_t ioh = tgpig[1]; | ||
| 4171 | const int64_t iow = tgpig[2]; | ||
| 4172 | |||
| 4173 | const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0; | ||
| 4174 | const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1; | ||
| 4175 | |||
| 4176 | int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw); | ||
| 4177 | |||
| 4178 | device T * pdst = (device T *) (dst); | ||
| 4179 | |||
| 4180 | if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { | ||
| 4181 | while (in < args.N) { | ||
| 4182 | pdst[offset_dst] = 0.0f; | ||
| 4183 | offset_dst += ntg[0]*args.CHW*OH*OW; | ||
| 4184 | |||
| 4185 | in += ntg[0]; | ||
| 4186 | } | ||
| 4187 | } else { | ||
| 4188 | int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw; | ||
| 4189 | |||
| 4190 | while (in < args.N) { | ||
| 4191 | pdst[offset_dst] = x[offset_src]; | ||
| 4192 | |||
| 4193 | offset_dst += ntg[0]*args.CHW*OH*OW; | ||
| 4194 | offset_src += ntg[0]*args.ofs0; | ||
| 4195 | |||
| 4196 | in += ntg[0]; | ||
| 4197 | } | ||
| 4198 | } | ||
| 4199 | } | ||
| 4200 | |||
| 4201 | template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>; | ||
| 4202 | template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>; | ||
| 4203 | |||
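`kernel_im2col` writes, for every `(n, oh, ow)` output position, a row of length `CHW = IC*KH*KW` holding the receptive-field values (or zero outside the padded input). A CPU sketch of the same index math, assuming a contiguous NCHW input so that `ofs0 = IC*IH*IW` and `ofs1 = IH*IW` (that layout is an assumption of this sketch, not stated by the kernel):

```cpp
#include <cstdint>

// CPU im2col sketch matching the destination layout used above:
// dst[(n*OH*OW + oh*OW + ow)*IC*KH*KW + ic*KH*KW + kh*KW + kw]
static void im2col_ref(
        const float * x, float * dst,
        int64_t N, int64_t IC, int64_t IH, int64_t IW,
        int64_t KH, int64_t KW, int64_t OH, int64_t OW,
        int32_t s0, int32_t s1, int32_t p0, int32_t p1, int32_t d0, int32_t d1) {
    const int64_t CHW = IC*KH*KW;
    for (int64_t n = 0; n < N; ++n)
    for (int64_t oh = 0; oh < OH; ++oh)
    for (int64_t ow = 0; ow < OW; ++ow)
    for (int64_t ic = 0; ic < IC; ++ic)
    for (int64_t kh = 0; kh < KH; ++kh)
    for (int64_t kw = 0; kw < KW; ++kw) {
        const int64_t iih = oh*s1 + kh*d1 - p1;   // input row for this tap
        const int64_t iiw = ow*s0 + kw*d0 - p0;   // input column for this tap

        const int64_t od = (n*OH*OW + oh*OW + ow)*CHW + ic*KH*KW + kh*KW + kw;

        dst[od] = (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)
            ? 0.0f
            : x[((n*IC + ic)*IH + iih)*IW + iiw];
    }
}
```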
| 4204 | // TODO: obsolete -- remove | ||
| 4205 | //typedef void (im2col_ext_t)( | ||
| 4206 | // constant ggml_metal_kargs_im2col & args, | ||
| 4207 | // device const float * x, | ||
| 4208 | // device char * dst, | ||
| 4209 | // uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4210 | // uint3 tgpg[[threadgroups_per_grid]], | ||
| 4211 | // uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 4212 | // uint3 ntg[[threads_per_threadgroup]]); | ||
| 4213 | // | ||
| 4214 | //template <typename T> | ||
| 4215 | //kernel void kernel_im2col_ext( | ||
| 4216 | // constant ggml_metal_kargs_im2col & args, | ||
| 4217 | // device const float * x, | ||
| 4218 | // device char * dst, | ||
| 4219 | // uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4220 | // uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW | ||
| 4221 | // uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 4222 | // uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] | ||
| 4223 | // const int64_t KHW = (int64_t)args.KHW; | ||
| 4224 | // | ||
| 4225 | // const int64_t d = tgpig[0] / args.CHW; | ||
| 4226 | // const int64_t chw = tgpig[0] % args.CHW; | ||
| 4227 | // const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) | ||
| 4228 | // const int64_t HW = tgpig[0] % KHW; | ||
| 4229 | // | ||
| 4230 | // const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; | ||
| 4231 | // if (tpitg_0 >= args.N) { | ||
| 4232 | // return; | ||
| 4233 | // } | ||
| 4234 | // | ||
| 4235 | // const int64_t tpitg_1 = HW / args.KW; | ||
| 4236 | // const int64_t tpitg_2 = HW % args.KW; | ||
| 4237 | // | ||
| 4238 | // const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; | ||
| 4239 | // const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; | ||
| 4240 | // | ||
| 4241 | // const int64_t offset_dst = | ||
| 4242 | // (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + | ||
| 4243 | // (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); | ||
| 4244 | // | ||
| 4245 | // device T * pdst = (device T *) (dst); | ||
| 4246 | // | ||
| 4247 | // if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { | ||
| 4248 | // pdst[offset_dst] = 0.0f; | ||
| 4249 | // } else { | ||
| 4250 | // const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; | ||
| 4251 | // pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; | ||
| 4252 | // } | ||
| 4253 | //} | ||
| 4254 | // | ||
| 4255 | //template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>; | ||
| 4256 | //template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>; | ||
| 4257 | |||
| 4258 | template <typename TK> | ||
| 4259 | kernel void kernel_conv_2d( | ||
| 4260 | constant ggml_metal_kargs_conv_2d & args, | ||
| 4261 | device const char * weights, | ||
| 4262 | device const char * src, | ||
| 4263 | device char * dst, | ||
| 4264 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4265 | uint3 tgpg[[threadgroups_per_grid]], | ||
| 4266 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 4267 | uint3 ntg[[threads_per_threadgroup]]) { | ||
| 4268 | |||
| 4269 | const uint threads_per_tg = ntg.x * ntg.y * ntg.z; | ||
| 4270 | const uint tg_index = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x; | ||
| 4271 | const uint local_thread = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x; | ||
| 4272 | const uint thread_index = tg_index * threads_per_tg + local_thread; | ||
| 4273 | const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z; | ||
| 4274 | const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW; | ||
| 4275 | |||
| 4276 | for (uint64_t index = thread_index; index < total_outputs; index += total_threads) { | ||
| 4277 | uint64_t tmp = index; | ||
| 4278 | |||
| 4279 | const int32_t ow = tmp % args.OW; tmp /= args.OW; | ||
| 4280 | const int32_t oh = tmp % args.OH; tmp /= args.OH; | ||
| 4281 | const int32_t oc = tmp % args.OC; tmp /= args.OC; | ||
| 4282 | const int32_t n = tmp; | ||
| 4283 | |||
| 4284 | float acc = 0.0f; | ||
| 4285 | |||
| 4286 | const int32_t base_x = ow*args.s0 - args.p0; | ||
| 4287 | const int32_t base_y = oh*args.s1 - args.p1; | ||
| 4288 | |||
| 4289 | int32_t ky_start = 0; | ||
| 4290 | if (base_y < 0) { | ||
| 4291 | ky_start = (-base_y + args.d1 - 1)/args.d1; | ||
| 4292 | } | ||
| 4293 | int32_t ky_end = args.KH; | ||
| 4294 | const int32_t y_max = args.IH - 1 - base_y; | ||
| 4295 | if (y_max < 0) { | ||
| 4296 | ky_end = ky_start; | ||
| 4297 | } else if (base_y + (args.KH - 1)*args.d1 >= args.IH) { | ||
| 4298 | ky_end = min(ky_end, y_max/args.d1 + 1); | ||
| 4299 | } | ||
| 4300 | |||
| 4301 | int32_t kx_start = 0; | ||
| 4302 | if (base_x < 0) { | ||
| 4303 | kx_start = (-base_x + args.d0 - 1)/args.d0; | ||
| 4304 | } | ||
| 4305 | int32_t kx_end = args.KW; | ||
| 4306 | const int32_t x_max = args.IW - 1 - base_x; | ||
| 4307 | if (x_max < 0) { | ||
| 4308 | kx_end = kx_start; | ||
| 4309 | } else if (base_x + (args.KW - 1)*args.d0 >= args.IW) { | ||
| 4310 | kx_end = min(kx_end, x_max/args.d0 + 1); | ||
| 4311 | } | ||
| 4312 | |||
| 4313 | if (ky_start < ky_end && kx_start < kx_end) { | ||
| 4314 | const uint64_t src_base_n = (uint64_t) n * args.nb13; | ||
| 4315 | const uint64_t w_base_oc = (uint64_t) oc * args.nb03; | ||
| 4316 | |||
| 4317 | for (int32_t ic = 0; ic < args.IC; ++ic) { | ||
| 4318 | const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12; | ||
| 4319 | const uint64_t w_base_ocic = w_base_oc + (uint64_t) ic * args.nb02; | ||
| 4320 | |||
| 4321 | for (int32_t ky = ky_start; ky < ky_end; ++ky) { | ||
| 4322 | const int32_t iy = base_y + ky*args.d1; | ||
| 4323 | const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11; | ||
| 4324 | const uint64_t w_base_row = w_base_ocic + (uint64_t) ky * args.nb01; | ||
| 4325 | |||
| 4326 | for (int32_t kx = kx_start; kx < kx_end; ++kx) { | ||
| 4327 | const int32_t ix = base_x + kx*args.d0; | ||
| 4328 | const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10; | ||
| 4329 | const uint64_t w_offs = w_base_row + (uint64_t) kx * args.nb00; | ||
| 4330 | |||
| 4331 | const float x = *(device const float *)(src + src_offs); | ||
| 4332 | const float w = (float) (*(device const TK *)(weights + w_offs)); | ||
| 4333 | |||
| 4334 | acc += x * w; | ||
| 4335 | } | ||
| 4336 | } | ||
| 4337 | } | ||
| 4338 | } | ||
| 4339 | |||
| 4340 | const uint64_t dst_offs = | ||
| 4341 | (uint64_t) n * args.nb3 + | ||
| 4342 | (uint64_t) oc * args.nb2 + | ||
| 4343 | (uint64_t) oh * args.nb1 + | ||
| 4344 | (uint64_t) ow * args.nb0; | ||
| 4345 | |||
| 4346 | *(device float *)(dst + dst_offs) = acc; | ||
| 4347 | } | ||
| 4348 | } | ||
| 4349 | |||
| 4350 | template [[host_name("kernel_conv_2d_f32_f32")]] | ||
| 4351 | kernel void kernel_conv_2d<float>( | ||
| 4352 | constant ggml_metal_kargs_conv_2d & args, | ||
| 4353 | device const char * weights, | ||
| 4354 | device const char * src, | ||
| 4355 | device char * dst, | ||
| 4356 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4357 | uint3 tgpg[[threadgroups_per_grid]], | ||
| 4358 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 4359 | uint3 ntg[[threads_per_threadgroup]]); | ||
| 4360 | |||
| 4361 | template [[host_name("kernel_conv_2d_f16_f32")]] | ||
| 4362 | kernel void kernel_conv_2d<half>( | ||
| 4363 | constant ggml_metal_kargs_conv_2d & args, | ||
| 4364 | device const char * weights, | ||
| 4365 | device const char * src, | ||
| 4366 | device char * dst, | ||
| 4367 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4368 | uint3 tgpg[[threadgroups_per_grid]], | ||
| 4369 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 4370 | uint3 ntg[[threads_per_threadgroup]]); | ||
| 4371 | |||
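The `ky_start/ky_end` and `kx_start/kx_end` bounds in `kernel_conv_2d` clip the kernel window to in-range input coordinates: a tap `k` is valid when `0 <= base + k*d < I`, i.e. `k >= ceil(-base/d)` and `k <= (I-1-base)/d`. A small sketch of that clipping per axis (hypothetical helper, same arithmetic as above):

```cpp
#include <algorithm>
#include <cstdint>

// Valid tap range for one axis of the direct convolution above: taps k in [k_start, k_end)
// satisfy 0 <= base + k*d < I, with base = o*s - p.
static void conv_tap_range(int32_t base, int32_t d, int32_t I, int32_t K,
                           int32_t & k_start, int32_t & k_end) {
    k_start = base < 0 ? (-base + d - 1)/d : 0;   // ceil(-base/d)
    const int32_t max_off = I - 1 - base;         // largest allowed k*d
    k_end = max_off < 0 ? k_start : std::min(K, max_off/d + 1);
}
```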
| 4372 | typedef void (conv_transpose_1d_t)( | ||
| 4373 | constant ggml_metal_kargs_conv_transpose_1d & args, | ||
| 4374 | device const float * src0, | ||
| 4375 | device const float * src1, | ||
| 4376 | device char * dst, | ||
| 4377 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4378 | uint3 tgpg[[threadgroups_per_grid]]); | ||
| 4379 | |||
| 4380 | template <typename T> | ||
| 4381 | kernel void kernel_conv_transpose_1d( | ||
| 4382 | constant ggml_metal_kargs_conv_transpose_1d & args, | ||
| 4383 | device const T * src0, | ||
| 4384 | device const float * src1, | ||
| 4385 | device char * dst, | ||
| 4386 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4387 | uint3 tgpg[[threadgroups_per_grid]]) { | ||
| 4388 | |||
| 4389 | float v = 0.0f; | ||
| 4390 | |||
| 4391 | for (int64_t c = 0; c < args.IC; c++) { | ||
| 4392 | const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1]; | ||
| 4393 | const int32_t input_offset = c * args.IL; | ||
| 4394 | |||
| 4395 | for (int64_t i = 0; i < args.IL; i++) { | ||
| 4396 | if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) { | ||
| 4397 | v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i]; | ||
| 4398 | } | ||
| 4399 | } | ||
| 4400 | } | ||
| 4401 | |||
| 4402 | device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1); | ||
| 4403 | |||
| 4404 | dst_ptr[0] = v; | ||
| 4405 | } | ||
| 4406 | |||
| 4407 | template [[host_name("kernel_conv_transpose_1d_f32_f32")]] | ||
| 4408 | kernel void kernel_conv_transpose_1d<float>( | ||
| 4409 | constant ggml_metal_kargs_conv_transpose_1d & args, | ||
| 4410 | device const float * src0, | ||
| 4411 | device const float * src1, | ||
| 4412 | device char * dst, | ||
| 4413 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4414 | uint3 tgpg[[threadgroups_per_grid]]); | ||
| 4415 | |||
| 4416 | template [[host_name("kernel_conv_transpose_1d_f16_f32")]] | ||
| 4417 | kernel void kernel_conv_transpose_1d<half>( | ||
| 4418 | constant ggml_metal_kargs_conv_transpose_1d & args, | ||
| 4419 | device const half * src0, | ||
| 4420 | device const float * src1, | ||
| 4421 | device char * dst, | ||
| 4422 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4423 | uint3 tgpg[[threadgroups_per_grid]]); | ||
| 4424 | |||
| 4425 | |||
| 4426 | typedef void (conv_transpose_2d_t)( | ||
| 4427 | constant ggml_metal_kargs_conv_transpose_2d & args, | ||
| 4428 | device const float * src0, | ||
| 4429 | device const float * src1, | ||
| 4430 | device char * dst, | ||
| 4431 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4432 | uint3 tgpg[[threadgroups_per_grid]]); | ||
| 4433 | |||
| 4434 | template <typename T> | ||
| 4435 | kernel void kernel_conv_transpose_2d( | ||
| 4436 | constant ggml_metal_kargs_conv_transpose_2d & args, | ||
| 4437 | device const T * src0, | ||
| 4438 | device const float * src1, | ||
| 4439 | device char * dst, | ||
| 4440 | threadgroup float * shared_sum [[threadgroup(0)]], | ||
| 4441 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4442 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 4443 | uint3 ntg[[threads_per_threadgroup]]) { | ||
| 4444 | |||
| 4445 | const int64_t out_x = tgpig[0]; | ||
| 4446 | const int64_t out_y = tgpig[1]; | ||
| 4447 | const int64_t out_c = tgpig[2]; | ||
| 4448 | |||
| 4449 | const int64_t kw = tpitg[0]; | ||
| 4450 | const int64_t kh = tpitg[1]; | ||
| 4451 | |||
| 4452 | float v = 0.0f; | ||
| 4453 | |||
| 4454 | for (int64_t in_c = 0; in_c < args.IC; in_c++) { | ||
| 4455 | int64_t in_y = out_y - kh; | ||
| 4456 | |||
| 4457 | if (in_y < 0 || in_y % args.s0) continue; | ||
| 4458 | |||
| 4459 | in_y /= args.s0; | ||
| 4460 | |||
| 4461 | if (in_y >= args.IH) continue; | ||
| 4462 | |||
| 4463 | int64_t in_x = out_x - kw; | ||
| 4464 | |||
| 4465 | if (in_x < 0 || in_x % args.s0) continue; | ||
| 4466 | |||
| 4467 | in_x /= args.s0; | ||
| 4468 | |||
| 4469 | if (in_x >= args.IW) continue; | ||
| 4470 | |||
| 4471 | const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x; | ||
| 4472 | const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw; | ||
| 4473 | |||
| 4474 | v += (float)src0[kernel_idx] * src1[input_idx]; | ||
| 4475 | } | ||
| 4476 | |||
| 4477 | const uint tid = tpitg.y * ntg.x + tpitg.x; | ||
| 4478 | shared_sum[tid] = v; | ||
| 4479 | |||
| 4480 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 4481 | |||
| 4482 | if (tid == 0) { | ||
| 4483 | float total = 0.0f; | ||
| 4484 | const uint num_threads = ntg.x * ntg.y; | ||
| 4485 | for (uint i = 0; i < num_threads; i++) { | ||
| 4486 | total += shared_sum[i]; | ||
| 4487 | } | ||
| 4488 | |||
| 4489 | device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2); | ||
| 4490 | dst_ptr[0] = total; | ||
| 4491 | } | ||
| 4492 | } | ||
| 4493 | |||
| 4494 | template [[host_name("kernel_conv_transpose_2d_f32_f32")]] | ||
| 4495 | kernel void kernel_conv_transpose_2d<float>( | ||
| 4496 | constant ggml_metal_kargs_conv_transpose_2d & args, | ||
| 4497 | device const float * src0, | ||
| 4498 | device const float * src1, | ||
| 4499 | device char * dst, | ||
| 4500 | threadgroup float * shared_sum [[threadgroup(0)]], | ||
| 4501 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4502 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 4503 | uint3 ntg[[threads_per_threadgroup]]); | ||
| 4504 | |||
| 4505 | template [[host_name("kernel_conv_transpose_2d_f16_f32")]] | ||
| 4506 | kernel void kernel_conv_transpose_2d<half>( | ||
| 4507 | constant ggml_metal_kargs_conv_transpose_2d & args, | ||
| 4508 | device const half * src0, | ||
| 4509 | device const float * src1, | ||
| 4510 | device char * dst, | ||
| 4511 | threadgroup float * shared_sum [[threadgroup(0)]], | ||
| 4512 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4513 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 4514 | uint3 ntg[[threads_per_threadgroup]]); | ||
| 4515 | |||
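`kernel_conv_transpose_2d` inverts the scatter relation `out = in*stride + k` per output pixel: an input position contributes through tap `k` only when `out - k` is non-negative, divisible by the stride, and maps inside the input. A one-function sketch of that check (hypothetical helper, not part of this commit):

```cpp
#include <cstdint>

// For a transposed convolution, output position `out` receives input position `in` through
// kernel tap `k` exactly when out = in*stride + k. Inverted per (out, k), as the kernel above does.
static bool conv_transpose_src(int64_t out, int64_t k, int64_t stride, int64_t I, int64_t & in) {
    const int64_t t = out - k;
    if (t < 0 || t % stride != 0) {
        return false;
    }
    in = t / stride;
    return in < I;
}
```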
| 4516 | kernel void kernel_upscale_f32( | ||
| 4517 | constant ggml_metal_kargs_upscale & args, | ||
| 4518 | device const char * src0, | ||
| 4519 | device char * dst, | ||
| 4520 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4521 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 4522 | uint3 ntg[[threads_per_threadgroup]]) { | ||
| 4523 | |||
| 4524 | const int64_t i3 = tgpig.z; | ||
| 4525 | const int64_t i2 = tgpig.y; | ||
| 4526 | const int64_t i1 = tgpig.x; | ||
| 4527 | |||
| 4528 | const int64_t i03 = i3/args.sf3; | ||
| 4529 | const int64_t i02 = i2/args.sf2; | ||
| 4530 | const int64_t i01 = i1/args.sf1; | ||
| 4531 | |||
| 4532 | for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { | ||
| 4533 | const int64_t i00 = i0/args.sf0; | ||
| 4534 | |||
| 4535 | device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); | ||
| 4536 | device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); | ||
| 4537 | |||
| 4538 | dst_ptr[0] = src0_ptr[0]; | ||
| 4539 | } | ||
| 4540 | } | ||
| 4541 | |||
| 4542 | kernel void kernel_pad_f32( | ||
| 4543 | constant ggml_metal_kargs_pad & args, | ||
| 4544 | device const char * src0, | ||
| 4545 | device char * dst, | ||
| 4546 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4547 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 4548 | uint3 ntg[[threads_per_threadgroup]]) { | ||
| 4549 | |||
| 4550 | const int64_t i3 = tgpig.z; | ||
| 4551 | const int64_t i2 = tgpig.y; | ||
| 4552 | const int64_t i1 = tgpig.x; | ||
| 4553 | |||
| 4554 | const int64_t i03 = i3; | ||
| 4555 | const int64_t i02 = i2; | ||
| 4556 | const int64_t i01 = i1; | ||
| 4557 | |||
| 4558 | device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); | ||
| 4559 | device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); | ||
| 4560 | |||
| 4561 | if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { | ||
| 4562 | for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { | ||
| 4563 | if (i0 < args.ne00) { | ||
| 4564 | dst_ptr[i0] = src0_ptr[i0]; | ||
| 4565 | } else { | ||
| 4566 | dst_ptr[i0] = 0.0f; | ||
| 4567 | } | ||
| 4568 | } | ||
| 4569 | |||
| 4570 | return; | ||
| 4571 | } | ||
| 4572 | |||
| 4573 | for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { | ||
| 4574 | dst_ptr[i0] = 0.0f; | ||
| 4575 | } | ||
| 4576 | } | ||
| 4577 | |||
| 4578 | kernel void kernel_pad_reflect_1d_f32( | ||
| 4579 | constant ggml_metal_kargs_pad_reflect_1d & args, | ||
| 4580 | device const char * src0, | ||
| 4581 | device char * dst, | ||
| 4582 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4583 | uint3 tgpg[[threadgroups_per_grid]], | ||
| 4584 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 4585 | uint3 ntg[[threads_per_threadgroup]]) { | ||
| 4586 | |||
| 4587 | const int64_t i3 = tgpig.z; | ||
| 4588 | const int64_t i2 = tgpig.y; | ||
| 4589 | const int64_t i1 = tgpig.x; | ||
| 4590 | |||
| 4591 | const int64_t i03 = i3; | ||
| 4592 | const int64_t i02 = i2; | ||
| 4593 | const int64_t i01 = i1; | ||
| 4594 | |||
| 4595 | device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); | ||
| 4596 | device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); | ||
| 4597 | |||
| 4598 | if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { | ||
| 4599 | for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { | ||
| 4600 | if (i0 < args.p0) { | ||
| 4601 | dst_ptr[i0] = src0_ptr[args.p0 - i0]; | ||
| 4602 | } else if (i0 < args.ne0 - args.p1) { | ||
| 4603 | dst_ptr[i0] = src0_ptr[i0 - args.p0]; | ||
| 4604 | } else { | ||
| 4605 | dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1]; | ||
| 4606 | } | ||
| 4607 | } | ||
| 4608 | } | ||
| 4609 | } | ||
| 4610 | |||
| 4611 | kernel void kernel_arange_f32( | ||
| 4612 | constant ggml_metal_kargs_arange & args, | ||
| 4613 | device char * dst, | ||
| 4614 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4615 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 4616 | uint3 ntg[[threads_per_threadgroup]]) { | ||
| 4617 | |||
| 4618 | device float * dst_ptr = (device float *) dst; | ||
| 4619 | |||
| 4620 | for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { | ||
| 4621 | dst_ptr[i0] = args.start + args.step * i0; | ||
| 4622 | } | ||
| 4623 | } | ||
| 4624 | |||
| 4625 | kernel void kernel_timestep_embedding_f32( | ||
| 4626 | constant ggml_metal_kargs_timestep_embedding & args, | ||
| 4627 | device const char * src0, | ||
| 4628 | device char * dst, | ||
| 4629 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4630 | uint3 tpitg[[thread_position_in_threadgroup]], | ||
| 4631 | uint3 ntg[[threads_per_threadgroup]]) { | ||
| 4632 | |||
| 4633 | int i = tgpig.x; | ||
| 4634 | device float * embed_data = (device float *)(dst + i*args.nb1); | ||
| 4635 | |||
| 4636 | int half_ = args.dim / 2; | ||
| 4637 | for (int j = tpitg.x; j < half_; j += ntg.x) { | ||
| 4638 | float timestep = ((device float *)src0)[i]; | ||
| 4639 | float freq = (float)exp(-log((float)args.max_period) * j / half_); | ||
| 4640 | float arg = timestep * freq; | ||
| 4641 | embed_data[j ] = cos(arg); | ||
| 4642 | embed_data[j + half_] = sin(arg); | ||
| 4643 | } | ||
| 4644 | |||
| 4645 | if (args.dim % 2 != 0 && tpitg.x == 0) { | ||
| 4646 | embed_data[2 * half_] = 0.f; | ||
| 4647 | } | ||
| 4648 | } | ||
| 4649 | |||
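`kernel_timestep_embedding_f32` is the standard sinusoidal embedding: `freq_j = exp(-log(max_period) * j / (dim/2))`, cosines in the first half of the row, sines in the second, and a trailing zero when `dim` is odd. A CPU transcription for reference (sketch only):

```cpp
#include <cmath>

// CPU transcription of the sinusoidal timestep embedding above (illustrative only).
static void timestep_embedding_ref(float timestep, float * embed, int dim, int max_period) {
    const int half = dim / 2;
    for (int j = 0; j < half; ++j) {
        const float freq = std::exp(-std::log((float) max_period) * j / half);
        const float arg  = timestep * freq;
        embed[j]        = std::cos(arg);
        embed[j + half] = std::sin(arg);
    }
    if (dim % 2 != 0) {
        embed[2*half] = 0.0f;   // odd dim: pad the last element with zero
    }
}
```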
| 4650 | // bitonic sort implementation following the CUDA kernels as reference | ||
| 4651 | typedef void (argsort_t)( | ||
| 4652 | constant ggml_metal_kargs_argsort & args, | ||
| 4653 | device const char * src0, | ||
| 4654 | device int32_t * dst, | ||
| 4655 | threadgroup int32_t * shmem_i32 [[threadgroup(0)]], | ||
| 4656 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4657 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 4658 | ushort3 ntg[[threads_per_threadgroup]]); | ||
| 4659 | |||
| 4660 | template<ggml_sort_order order> | ||
| 4661 | kernel void kernel_argsort_f32_i32( | ||
| 4662 | constant ggml_metal_kargs_argsort & args, | ||
| 4663 | device const char * src0, | ||
| 4664 | device int32_t * dst, | ||
| 4665 | threadgroup int32_t * shmem_i32 [[threadgroup(0)]], | ||
| 4666 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4667 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 4668 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 4669 | // bitonic sort | ||
| 4670 | const int col = tpitg[0]; | ||
| 4671 | const int ib = tgpig[0] / args.ne01; | ||
| 4672 | |||
| 4673 | const int i00 = ib*ntg.x; | ||
| 4674 | const int i01 = tgpig[0] % args.ne01; | ||
| 4675 | const int i02 = tgpig[1]; | ||
| 4676 | const int i03 = tgpig[2]; | ||
| 4677 | |||
| 4678 | device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03); | ||
| 4679 | |||
| 4680 | // initialize indices | ||
| 4681 | shmem_i32[col] = i00 + col; | ||
| 4682 | |||
| 4683 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 4684 | |||
| 4685 | for (int k = 2; k <= ntg.x; k *= 2) { | ||
| 4686 | for (int j = k / 2; j > 0; j /= 2) { | ||
| 4687 | int ixj = col ^ j; | ||
| 4688 | if (ixj > col) { | ||
| 4689 | if ((col & k) == 0) { | ||
| 4690 | if (shmem_i32[col] >= args.ne00 || | ||
| 4691 | (shmem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ? | ||
| 4692 | src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] : | ||
| 4693 | src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]])) | ||
| 4694 | ) { | ||
| 4695 | SWAP(shmem_i32[col], shmem_i32[ixj]); | ||
| 4696 | } | ||
| 4697 | } else { | ||
| 4698 | if (shmem_i32[ixj] >= args.ne00 || | ||
| 4699 | (shmem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ? | ||
| 4700 | src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] : | ||
| 4701 | src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]])) | ||
| 4702 | ) { | ||
| 4703 | SWAP(shmem_i32[col], shmem_i32[ixj]); | ||
| 4704 | } | ||
| 4705 | } | ||
| 4706 | } | ||
| 4707 | |||
| 4708 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 4709 | } | ||
| 4710 | } | ||
| 4711 | |||
| 4712 | const int64_t i0 = ib*args.top_k; | ||
| 4713 | |||
| 4714 | // copy the result to dst without the padding | ||
| 4715 | if (i0 + col < args.ne0 && col < args.top_k) { | ||
| 4716 | dst += i0 + args.ne0*i01 + args.ne0*args.ne1*i02 + args.ne0*args.ne1*args.ne2*i03; | ||
| 4717 | |||
| 4718 | dst[col] = shmem_i32[col]; | ||
| 4719 | } | ||
| 4720 | } | ||
| 4721 | |||
| 4722 | template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>; | ||
| 4723 | template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>; | ||
| 4724 | |||
| 4725 | typedef void (argsort_merge_t)( | ||
| 4726 | constant ggml_metal_kargs_argsort_merge & args, | ||
| 4727 | device const char * src0, | ||
| 4728 | device const int32_t * tmp, | ||
| 4729 | device int32_t * dst, | ||
| 4730 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4731 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 4732 | ushort3 ntg[[threads_per_threadgroup]]); | ||
| 4733 | |||
| 4734 | template<ggml_sort_order order> | ||
| 4735 | kernel void kernel_argsort_merge_f32_i32( | ||
| 4736 | constant ggml_metal_kargs_argsort_merge & args, | ||
| 4737 | device const char * src0, | ||
| 4738 | device const int32_t * tmp, | ||
| 4739 | device int32_t * dst, | ||
| 4740 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4741 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 4742 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 4743 | |||
| 4744 | const int im = tgpig[0] / args.ne01; | ||
| 4745 | const int i01 = tgpig[0] % args.ne01; | ||
| 4746 | const int i02 = tgpig[1]; | ||
| 4747 | const int i03 = tgpig[2]; | ||
| 4748 | |||
| 4749 | const int start = im * (2 * args.len); | ||
| 4750 | |||
| 4751 | const int len0 = MIN(args.len, MAX(0, args.ne0 - (int)(start))); | ||
| 4752 | const int len1 = MIN(args.len, MAX(0, args.ne0 - (int)(start + args.len))); | ||
| 4753 | |||
| 4754 | const int total = len0 + len1; | ||
| 4755 | |||
| 4756 | device const int32_t * tmp0 = tmp + start | ||
| 4757 | + i01*args.ne0 | ||
| 4758 | + i02*args.ne0*args.ne01 | ||
| 4759 | + i03*args.ne0*args.ne01*args.ne02; | ||
| 4760 | |||
| 4761 | device const int32_t * tmp1 = tmp0 + args.len; | ||
| 4762 | |||
| 4763 | dst += start | ||
| 4764 | + i01*args.top_k | ||
| 4765 | + i02*args.top_k*args.ne01 | ||
| 4766 | + i03*args.top_k*args.ne01*args.ne02; | ||
| 4767 | |||
| 4768 | device const float * src0_row = (device const float *)(src0 | ||
| 4769 | + args.nb01*i01 | ||
| 4770 | + args.nb02*i02 | ||
| 4771 | + args.nb03*i03); | ||
| 4772 | |||
| 4773 | if (total == 0) { | ||
| 4774 | return; | ||
| 4775 | } | ||
| 4776 | |||
| 4777 | const int chunk = (total + ntg.x - 1) / ntg.x; | ||
| 4778 | |||
| 4779 | const int k0 = tpitg.x * chunk; | ||
| 4780 | const int k1 = MIN(MIN(k0 + chunk, total), args.top_k); | ||
| 4781 | |||
| 4782 | if (k0 >= args.top_k) { | ||
| 4783 | return; | ||
| 4784 | } | ||
| 4785 | |||
| 4786 | if (k0 >= total) { | ||
| 4787 | return; | ||
| 4788 | } | ||
| 4789 | |||
| 4790 | int low = k0 > len1 ? k0 - len1 : 0; | ||
| 4791 | int high = MIN(k0, len0); | ||
| 4792 | |||
| 4793 | // binary-search partition (i, j) such that i + j = k | ||
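| | // "merge path" style partition: find i in [low, high] such that taking the first i elements of | ||
| | // tmp0 and the first j = k0 - i elements of tmp1 yields exactly the first k0 elements of the merged | ||
| | // order; each thread can then merge its output range [k0, k1) independently of the other threads | ||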
| 4794 | while (low < high) { | ||
| 4795 | const int mid = (low + high) >> 1; | ||
| 4796 | |||
| 4797 | const int32_t idx0 = tmp0[mid]; | ||
| 4798 | const int32_t idx1 = tmp1[k0 - mid - 1]; | ||
| 4799 | |||
| 4800 | const float val0 = src0_row[idx0]; | ||
| 4801 | const float val1 = src0_row[idx1]; | ||
| 4802 | |||
| 4803 | bool take_left; | ||
| 4804 | if (order == GGML_SORT_ORDER_ASC) { | ||
| 4805 | take_left = (val0 <= val1); | ||
| 4806 | } else { | ||
| 4807 | take_left = (val0 >= val1); | ||
| 4808 | } | ||
| 4809 | |||
| 4810 | if (take_left) { | ||
| 4811 | low = mid + 1; | ||
| 4812 | } else { | ||
| 4813 | high = mid; | ||
| 4814 | } | ||
| 4815 | } | ||
| 4816 | |||
| 4817 | int i = low; | ||
| 4818 | int j = k0 - i; | ||
| 4819 | |||
| 4820 | // keep the merge fronts in registers | ||
| 4821 | int32_t idx0 = 0; | ||
| 4822 | float val0 = 0.0f; | ||
| 4823 | if (i < len0) { | ||
| 4824 | idx0 = tmp0[i]; | ||
| 4825 | val0 = src0_row[idx0]; | ||
| 4826 | } | ||
| 4827 | |||
| 4828 | int32_t idx1 = 0; | ||
| 4829 | float val1 = 0.0f; | ||
| 4830 | if (j < len1) { | ||
| 4831 | idx1 = tmp1[j]; | ||
| 4832 | val1 = src0_row[idx1]; | ||
| 4833 | } | ||
| 4834 | |||
| 4835 | for (int k = k0; k < k1; ++k) { | ||
| 4836 | int32_t out_idx; | ||
| 4837 | |||
| 4838 | if (i >= len0) { | ||
| 4839 | while (k < k1) { | ||
| 4840 | dst[k++] = tmp1[j++]; | ||
| 4841 | } | ||
| 4842 | break; | ||
| 4843 | } else if (j >= len1) { | ||
| 4844 | while (k < k1) { | ||
| 4845 | dst[k++] = tmp0[i++]; | ||
| 4846 | } | ||
| 4847 | break; | ||
| 4848 | } else { | ||
| 4849 | bool take_left; | ||
| 4850 | |||
| 4851 | if (order == GGML_SORT_ORDER_ASC) { | ||
| 4852 | take_left = (val0 <= val1); | ||
| 4853 | } else { | ||
| 4854 | take_left = (val0 >= val1); | ||
| 4855 | } | ||
| 4856 | |||
| 4857 | if (take_left) { | ||
| 4858 | out_idx = idx0; | ||
| 4859 | ++i; | ||
| 4860 | if (i < len0) { | ||
| 4861 | idx0 = tmp0[i]; | ||
| 4862 | val0 = src0_row[idx0]; | ||
| 4863 | } | ||
| 4864 | } else { | ||
| 4865 | out_idx = idx1; | ||
| 4866 | ++j; | ||
| 4867 | if (j < len1) { | ||
| 4868 | idx1 = tmp1[j]; | ||
| 4869 | val1 = src0_row[idx1]; | ||
| 4870 | } | ||
| 4871 | } | ||
| 4872 | } | ||
| 4873 | |||
| 4874 | dst[k] = out_idx; | ||
| 4875 | } | ||
| 4876 | } | ||
| 4877 | |||
| 4878 | template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>; | ||
| 4879 | template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>; | ||
| 4880 | |||
| 4881 | constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]]; | ||
| 4882 | |||
| 4883 | constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]]; | ||
| 4884 | |||
| 4885 | // pad the last chunk of C elements of k and v into an extra pad buffer | ||
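| | // layout of the dst pad buffer (see the pointer setup below): | ||
| | //   [ K pad : C rows per (i2, i3) ][ V pad : C rows per (i2, i3) ][ mask pad : C halfs per query row ] | ||
| | // K/V rows past ne11 are zero-filled and the corresponding mask entries are set to -MAXHALF, | ||
| | // so the padded positions are masked out of the attention scores | ||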
| 4886 | kernel void kernel_flash_attn_ext_pad( | ||
| 4887 | constant ggml_metal_kargs_flash_attn_ext_pad & args, | ||
| 4888 | device const char * k, | ||
| 4889 | device const char * v, | ||
| 4890 | device const char * mask, | ||
| 4891 | device char * dst, | ||
| 4892 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4893 | ushort tiitg[[thread_index_in_threadgroup]], | ||
| 4894 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 4895 | const int32_t C = FC_flash_attn_ext_pad_ncpsg; | ||
| 4896 | |||
| 4897 | device char * k_pad = dst; | ||
| 4898 | device char * v_pad = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3; | ||
| 4899 | device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3; | ||
| 4900 | |||
| 4901 | const int32_t icp = args.ne11 % C; | ||
| 4902 | const int32_t ic0 = args.ne11 - icp; | ||
| 4903 | |||
| 4904 | const int32_t i1 = tgpig[0]; | ||
| 4905 | const int32_t i2 = tgpig[1]; | ||
| 4906 | const int32_t i3 = tgpig[2]; | ||
| 4907 | |||
| 4908 | if (i2 < args.ne_12_2 && i3 < args.ne_12_3) { | ||
| 4909 | device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3; | ||
| 4910 | device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3; | ||
| 4911 | |||
| 4912 | device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3; | ||
| 4913 | device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3; | ||
| 4914 | |||
| 4915 | if (i1 >= icp) { | ||
| 4916 | // the exact value used here is not important since we rely on masking out the scores in the attention | ||
| 4917 | for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { | ||
| 4918 | k_dst[i] = 0; | ||
| 4919 | } | ||
| 4920 | for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) { | ||
| 4921 | v_dst[i] = 0; | ||
| 4922 | } | ||
| 4923 | } else { | ||
| 4924 | for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { | ||
| 4925 | k_dst[i] = k_src[i]; | ||
| 4926 | } | ||
| 4927 | for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) { | ||
| 4928 | v_dst[i] = v_src[i]; | ||
| 4929 | } | ||
| 4930 | } | ||
| 4931 | } | ||
| 4932 | |||
| 4933 | if (FC_flash_attn_ext_pad_has_mask) { | ||
| 4934 | if (i2 < args.ne32 && i3 < args.ne33) { | ||
| 4935 | for (int ib = i1; ib < args.ne31; ib += C) { | ||
| 4936 | device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0; | ||
| 4937 | device half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3; | ||
| 4938 | |||
| 4939 | for (int i = tiitg; i < C; i += ntg.x) { | ||
| 4940 | if (i >= icp) { | ||
| 4941 | mask_dst[i] = -MAXHALF; | ||
| 4942 | } else { | ||
| 4943 | mask_dst[i] = mask_src[i]; | ||
| 4944 | } | ||
| 4945 | } | ||
| 4946 | } | ||
| 4947 | } | ||
| 4948 | } | ||
| 4949 | } | ||
| 4950 | |||
| 4951 | constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]]; | ||
| 4952 | constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]]; | ||
| 4953 | |||
| 4954 | // scan the mask blocks and classify each one: | ||
| 4955 | // 0 - masked (i.e. full of -INF, skip) | ||
| 4956 | // 1 - not masked (i.e. at least one element of the mask is not -INF) | ||
| 4957 | // 2 - all zero | ||
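| | // distinguishing the all-zero case lets the main kernel skip adding the mask to the scores | ||
| | // entirely for such blocks (see the blk_cur != 2 check in kernel_flash_attn_ext_impl) | ||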
| 4958 | kernel void kernel_flash_attn_ext_blk( | ||
| 4959 | constant ggml_metal_kargs_flash_attn_ext_blk & args, | ||
| 4960 | device const char * mask, | ||
| 4961 | device char * dst, | ||
| 4962 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 4963 | ushort tiisg[[thread_index_in_simdgroup]]) { | ||
| 4964 | // block size C x Q | ||
| 4965 | const int32_t Q = FC_flash_attn_ext_blk_nqptg; | ||
| 4966 | const int32_t C = FC_flash_attn_ext_blk_ncpsg; | ||
| 4967 | |||
| 4968 | constexpr short NW = N_SIMDWIDTH; | ||
| 4969 | |||
| 4970 | const int32_t i3 = tgpig[2]/args.ne32; | ||
| 4971 | const int32_t i2 = tgpig[2]%args.ne32; | ||
| 4972 | const int32_t i1 = tgpig[1]; | ||
| 4973 | const int32_t i0 = tgpig[0]; | ||
| 4974 | |||
| 4975 | char res = i0*C + C > args.ne30 ? 1 : 0; | ||
| 4976 | |||
| 4977 | device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg; | ||
| 4978 | |||
| 4979 | // detailed check of the elements of the block | ||
| 4980 | if ((C > NW || Q > 1) && res == 0) { | ||
| 4981 | half mmin = MAXHALF; | ||
| 4982 | half mmax = -MAXHALF; | ||
| 4983 | |||
| 4984 | FOR_UNROLL (short j = 0; j < Q; ++j) { | ||
| 4985 | FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) { | ||
| 4986 | mmin = min(mmin, mask_src[ii*NW]); | ||
| 4987 | mmax = max(mmax, mask_src[ii*NW]); | ||
| 4988 | } | ||
| 4989 | |||
| 4990 | mask_src += args.nb31/2; | ||
| 4991 | } | ||
| 4992 | |||
| 4993 | mmin = simd_min(mmin); | ||
| 4994 | mmax = simd_max(mmax); | ||
| 4995 | |||
| 4996 | if (mmax > -MAXHALF) { | ||
| 4997 | if (mmin == 0.0 && mmax == 0.0) { | ||
| 4998 | res = 2; | ||
| 4999 | } else { | ||
| 5000 | res = 1; | ||
| 5001 | } | ||
| 5002 | } | ||
| 5003 | } | ||
| 5004 | |||
| 5005 | const int32_t nblk1 = ((args.ne01 + Q - 1)/Q); | ||
| 5006 | const int32_t nblk0 = ((args.ne30 + C - 1)/C); | ||
| 5007 | |||
| 5008 | if (tiisg == 0) { | ||
| 5009 | dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res; | ||
| 5010 | } | ||
| 5011 | } | ||
| 5012 | |||
| 5013 | constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]]; | ||
| 5014 | constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]]; | ||
| 5015 | constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]]; | ||
| 5016 | constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]]; | ||
| 5017 | constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]]; | ||
| 5018 | |||
| 5019 | constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; | ||
| 5020 | |||
| 5021 | //constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; | ||
| 5022 | //constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]]; | ||
| 5023 | //constant float FC_flash_attn_ext_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT + 12)]]; | ||
| 5024 | |||
| 5025 | constant int32_t FC_flash_attn_ext_ns10 [[function_constant(FC_FLASH_ATTN_EXT + 20)]]; | ||
| 5026 | constant int32_t FC_flash_attn_ext_ns20 [[function_constant(FC_FLASH_ATTN_EXT + 21)]]; | ||
| 5027 | constant int32_t FC_flash_attn_ext_nsg [[function_constant(FC_FLASH_ATTN_EXT + 22)]]; | ||
| 5028 | |||
| 5029 | // ref: https://arxiv.org/pdf/2307.08691.pdf | ||
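| | // FlashAttention-2 style kernel: each threadgroup handles Q query rows and walks the KV cache in | ||
| | // chunks of C columns, keeping per-row running statistics (max M, sum S) and rescaling the partial | ||
| | // output accumulator after each chunk - see the "online softmax" section below | ||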
| 5030 | template< | ||
| 5031 | typename q_t, // query types in shared memory | ||
| 5032 | typename q4_t, | ||
| 5033 | typename q8x8_t, | ||
| 5034 | typename k_t, // key types in shared memory | ||
| 5035 | typename k4x4_t, | ||
| 5036 | typename k8x8_t, | ||
| 5037 | typename v_t, // value types in shared memory | ||
| 5038 | typename v4x4_t, | ||
| 5039 | typename v8x8_t, | ||
| 5040 | typename qk_t, // Q*K types | ||
| 5041 | typename qk8x8_t, | ||
| 5042 | typename s_t, // soft-max types | ||
| 5043 | typename s2_t, | ||
| 5044 | typename s8x8_t, | ||
| 5045 | typename o_t, // attention accumulation types | ||
| 5046 | typename o4_t, | ||
| 5047 | typename o8x8_t, | ||
| 5048 | typename kd4x4_t, // key type in device memory | ||
| 5049 | short nl_k, | ||
| 5050 | void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), | ||
| 5051 | typename vd4x4_t, // value type in device memory | ||
| 5052 | short nl_v, | ||
| 5053 | void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), | ||
| 5054 | short DK, // K head size | ||
| 5055 | short DV, // V head size | ||
| 5056 | short Q, // queries per threadgroup | ||
| 5057 | short C, // cache items per threadgroup | ||
| 5058 | short NSG> // number of simd groups | ||
| 5059 | void kernel_flash_attn_ext_impl( | ||
| 5060 | constant ggml_metal_kargs_flash_attn_ext & args, | ||
| 5061 | device const char * q, | ||
| 5062 | device const char * k, | ||
| 5063 | device const char * v, | ||
| 5064 | device const char * mask, | ||
| 5065 | device const char * sinks, | ||
| 5066 | device const char * pad, | ||
| 5067 | device const char * blk, | ||
| 5068 | device char * dst, | ||
| 5069 | threadgroup half * shmem_f16, | ||
| 5070 | uint3 tgpig, | ||
| 5071 | ushort tiisg, | ||
| 5072 | ushort sgitg) { | ||
| 5073 | const ushort iq3 = tgpig[2]; | ||
| 5074 | const ushort iq2 = tgpig[1]; | ||
| 5075 | const ushort iq1 = tgpig[0]*Q; | ||
| 5076 | |||
| 5077 | #define NS10 (FC_flash_attn_ext_ns10) | ||
| 5078 | #define NS20 (FC_flash_attn_ext_ns20) | ||
| 5079 | |||
| 5080 | // note: I had some concerns that using this instead of the ugly macros above was affecting performance | ||
| 5081 | // need to re-check carefully and if no regressions are observed - remove the macros | ||
| 5082 | // the concern is that maybe using const variables requires extra registers? but not sure if the compiler | ||
| 5083 | // is clever enough to avoid this. unfortunately, using constexpr is not possible with FC | ||
| 5084 | //const short NS10 = FC_flash_attn_ext_ns10; | ||
| 5085 | //const short NS20 = FC_flash_attn_ext_ns20; | ||
| 5086 | |||
| 5087 | constexpr short KV = 8; | ||
| 5088 | |||
| 5089 | constexpr short DK4 = DK/4; | ||
| 5090 | constexpr short DK8 = DK/8; | ||
| 5091 | constexpr short DK16 = DK/16; | ||
| 5092 | constexpr short DV4 = DV/4; | ||
| 5093 | //constexpr short DV8 = DV/8; | ||
| 5094 | constexpr short DV16 = DV/16; | ||
| 5095 | |||
| 5096 | constexpr short PV = PAD2(DV, 64); | ||
| 5097 | constexpr short PV4 = PV/4; | ||
| 5098 | constexpr short PV8 = PV/8; | ||
| 5099 | //constexpr short PV16 = PV/16; | ||
| 5100 | |||
| 5101 | constexpr short NW = N_SIMDWIDTH; | ||
| 5102 | constexpr short NQ = Q/NSG; | ||
| 5103 | constexpr short SH = 2*C; // shared memory per simdgroup (s_t == float) | ||
| 5104 | |||
| 5105 | constexpr short TS = 2*SH; | ||
| 5106 | constexpr short T = DK + 2*PV; // shared memory size per query in (half) | ||
| 5107 | |||
| 5108 | threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*T); // holds the query data | ||
| 5109 | threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*T); // same as above but in q4_t | ||
| 5110 | threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*T + Q*DK); // the result for all queries in 8x8 matrices (the O matrix from the paper) | ||
| 5111 | threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*T + Q*DK); | ||
| 5112 | threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + Q*T); // scratch buffer for attention, mask and diagonal matrix | ||
| 5113 | threadgroup s2_t * ss2 = (threadgroup s2_t *) (shmem_f16 + Q*T); // same as above but in s2_t | ||
| 5114 | |||
| 5115 | threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load K in shared memory | ||
| 5116 | threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in k4x4_t | ||
| 5117 | |||
| 5118 | threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // scratch buffer to load V in shared memory | ||
| 5119 | threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T + Q*TS); // same as above but in v4x4_t | ||
| 5120 | |||
| 5121 | // mask storage in shared mem | ||
| 5122 | threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C); | ||
| 5123 | |||
| 5124 | // per-query mask pointers | ||
| 5125 | device const half2 * pm2[NQ]; | ||
| 5126 | |||
| 5127 | FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { | ||
| 5128 | const short j = jj*NSG + sgitg; | ||
| 5129 | |||
| 5130 | pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); | ||
| 5131 | } | ||
| 5132 | |||
| 5133 | { | ||
| 5134 | const int32_t nblk1 = ((args.ne01 + Q - 1)/Q); | ||
| 5135 | const int32_t nblk0 = ((args.ne11 + C - 1)/C); | ||
| 5136 | |||
| 5137 | blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0; | ||
| 5138 | } | ||
| 5139 | |||
| 5140 | { | ||
| 5141 | q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; | ||
| 5142 | |||
| 5143 | const short ikv2 = iq2/(args.ne02/args.ne_12_2); | ||
| 5144 | const short ikv3 = iq3/(args.ne03/args.ne_12_3); | ||
| 5145 | |||
| 5146 | k += ikv2*args.nb12 + ikv3*args.nb13; | ||
| 5147 | v += ikv2*args.nb22 + ikv3*args.nb23; | ||
| 5148 | } | ||
| 5149 | |||
| 5150 | // load heads from Q to shared memory | ||
| 5151 | FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { | ||
| 5152 | const short j = jj*NSG + sgitg; | ||
| 5153 | |||
| 5154 | device const float4 * q4 = (device const float4 *) ((device const char *) q + j*args.nb01); | ||
| 5155 | |||
| 5156 | for (short i = tiisg; i < DK4; i += NW) { | ||
| 5157 | if (iq1 + j < args.ne01) { | ||
| 5158 | sq4[j*DK4 + i] = (q4_t) q4[i]; | ||
| 5159 | } else { | ||
| 5160 | sq4[j*DK4 + i] = 0; | ||
| 5161 | } | ||
| 5162 | } | ||
| 5163 | } | ||
| 5164 | |||
| 5165 | // zero out the output accumulator and the attention scratch | ||
| 5166 | FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { | ||
| 5167 | const short j = jj*NSG + sgitg; | ||
| 5168 | |||
| 5169 | for (short i = tiisg; i < DV4; i += NW) { | ||
| 5170 | so4[j*PV4 + i] = 0; | ||
| 5171 | } | ||
| 5172 | |||
| 5173 | for (short i = tiisg; i < SH; i += NW) { | ||
| 5174 | ss[j*SH + i] = 0.0f; | ||
| 5175 | } | ||
| 5176 | } | ||
| 5177 | |||
| 5178 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 5179 | |||
| 5180 | float S[NQ] = { [0 ... NQ-1] = 0.0f }; | ||
| 5181 | |||
| 5182 | { | ||
| 5183 | float M[NQ] = { [0 ... NQ-1] = -FLT_MAX/2 }; | ||
| 5184 | |||
| 5185 | float slope = 1.0f; | ||
| 5186 | |||
| 5187 | // ALiBi | ||
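| | // per-head ALiBi slope: heads 0 .. n_head_log2-1 use m0^(h+1), the remaining heads use | ||
| | // m1^(2*(h - n_head_log2) + 1); the slope scales the additive mask applied to the scores below | ||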
| 5188 | if (FC_flash_attn_ext_has_bias) { | ||
| 5189 | const short h = iq2; | ||
| 5190 | |||
| 5191 | const float base = h < args.n_head_log2 ? args.m0 : args.m1; | ||
| 5192 | const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; | ||
| 5193 | |||
| 5194 | slope = pow(base, exph); | ||
| 5195 | } | ||
| 5196 | |||
| 5197 | // loop over the KV cache | ||
| 5198 | // each simdgroup handles blocks of Q rows and C columns | ||
| 5199 | for (int ic0 = 0; ; ++ic0) { | ||
| 5200 | int ic = ic0*C; | ||
| 5201 | if (ic >= args.ne11) { | ||
| 5202 | break; | ||
| 5203 | } | ||
| 5204 | |||
| 5205 | // the last partial chunk uses the pad buffer as source | ||
| 5206 | if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) { | ||
| 5207 | k = pad; | ||
| 5208 | v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; | ||
| 5209 | mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; | ||
| 5210 | |||
| 5211 | const short ikv2 = iq2/(args.ne02/args.ne_12_2); | ||
| 5212 | const short ikv3 = iq3/(args.ne03/args.ne_12_3); | ||
| 5213 | |||
| 5214 | k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; | ||
| 5215 | v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; | ||
| 5216 | |||
| 5217 | if (!FC_flash_attn_ext_has_mask) { | ||
| 5218 | threadgroup half * sm = (threadgroup half *) (sm2); | ||
| 5219 | |||
| 5220 | FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { | ||
| 5221 | const short j = jj*NSG + sgitg; | ||
| 5222 | |||
| 5223 | for (short i = tiisg; i < C; i += NW) { | ||
| 5224 | if (ic + i >= args.ne11) { | ||
| 5225 | sm[2*j*SH + i] = -MAXHALF; | ||
| 5226 | } | ||
| 5227 | } | ||
| 5228 | } | ||
| 5229 | } else { | ||
| 5230 | FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { | ||
| 5231 | const short j = jj*NSG + sgitg; | ||
| 5232 | |||
| 5233 | pm2[jj] = (device const half2 *) ((device const half *) mask + | ||
| 5234 | (iq1 + j)*C + | ||
| 5235 | (iq2%args.ne32)*(C*args.ne31) + | ||
| 5236 | (iq3%args.ne33)*(C*args.ne31*args.ne32)); | ||
| 5237 | } | ||
| 5238 | } | ||
| 5239 | |||
| 5240 | ic = 0; | ||
| 5241 | } | ||
| 5242 | |||
| 5243 | char blk_cur = 1; | ||
| 5244 | |||
| 5245 | // read the mask into shared mem | ||
| 5246 | if (FC_flash_attn_ext_has_mask) { | ||
| 5247 | blk_cur = blk[ic0]; | ||
| 5248 | |||
| 5249 | if (blk_cur == 0) { | ||
| 5250 | FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { | ||
| 5251 | pm2[jj] += NW; | ||
| 5252 | } | ||
| 5253 | |||
| 5254 | continue; | ||
| 5255 | } | ||
| 5256 | |||
| 5257 | if (blk_cur == 1) { | ||
| 5258 | FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { | ||
| 5259 | const short j = jj*NSG + sgitg; | ||
| 5260 | |||
| 5261 | if (FC_flash_attn_ext_bc_mask) { | ||
| 5262 | sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); | ||
| 5263 | } else { | ||
| 5264 | sm2[j*SH + tiisg] = pm2[jj][tiisg]; | ||
| 5265 | } | ||
| 5266 | |||
| 5267 | pm2[jj] += NW; | ||
| 5268 | } | ||
| 5269 | } else if (blk_cur == 2) { | ||
| 5270 | FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { | ||
| 5271 | pm2[jj] += NW; | ||
| 5272 | } | ||
| 5273 | } | ||
| 5274 | |||
| 5275 | #if 0 | ||
| 5276 | // note: old -INF block optimization - obsoleted by pre-computing non-masked blocks | ||
| 5277 | |||
| 5278 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 5279 | |||
| 5280 | // used to detect blocks full of -INF | ||
| 5281 | // skip only when the entire threadgroup is masked | ||
| 5282 | half2 smax2(-MAXHALF/2, -MAXHALF/2); | ||
| 5283 | |||
| 5284 | FOR_UNROLL (short j = 0; j < Q; ++j) { | ||
| 5285 | smax2 = max(smax2, sm2[j*SH + tiisg]); | ||
| 5286 | } | ||
| 5287 | |||
| 5288 | smax2 = simd_max(smax2); | ||
| 5289 | |||
| 5290 | if (max(smax2[0], smax2[1]) <= -MAXHALF/2) { | ||
| 5291 | // this barrier is important | ||
| 5292 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 5293 | |||
| 5294 | continue; | ||
| 5295 | } | ||
| 5296 | #endif | ||
| 5297 | } | ||
| 5298 | |||
| 5299 | // Q*K^T | ||
| 5300 | // this is a compile-time check, so it does not have runtime overhead | ||
| 5301 | if (is_same<kd4x4_t, k4x4_t>::value) { | ||
| 5302 | // we can read directly from global memory | ||
| 5303 | device const k_t * pk = (device const k_t *) (k + ic*args.nb11); | ||
| 5304 | threadgroup const q_t * pq = sq; | ||
| 5305 | threadgroup s_t * ps = ss; | ||
| 5306 | |||
| 5307 | pk += sgitg*(8*NS10); | ||
| 5308 | ps += sgitg*(8*1); | ||
| 5309 | |||
| 5310 | static_assert((C/8) % NSG == 0, ""); | ||
| 5311 | |||
| 5312 | constexpr short NC = (C/8)/NSG; | ||
| 5313 | |||
| 5314 | FOR_UNROLL (short cc = 0; cc < NC; ++cc) { | ||
| 5315 | qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f); | ||
| 5316 | |||
| 5317 | if (DK % 16 != 0) { | ||
| 5318 | k8x8_t mk; | ||
| 5319 | q8x8_t mq; | ||
| 5320 | |||
| 5321 | FOR_UNROLL (short i = 0; i < DK8; ++i) { | ||
| 5322 | simdgroup_barrier(mem_flags::mem_none); | ||
| 5323 | |||
| 5324 | simdgroup_load(mk, pk + 8*i, NS10, 0, true); | ||
| 5325 | simdgroup_load(mq, pq + 8*i, DK); | ||
| 5326 | |||
| 5327 | simdgroup_barrier(mem_flags::mem_none); | ||
| 5328 | |||
| 5329 | simdgroup_multiply_accumulate(mqk, mq, mk, mqk); | ||
| 5330 | } | ||
| 5331 | } else { | ||
| 5332 | k8x8_t mk[2]; | ||
| 5333 | q8x8_t mq[2]; | ||
| 5334 | |||
| 5335 | // note: too much unrolling can tank the performance for large heads | ||
| 5336 | #pragma unroll (MIN(DK8/2, 4*NSG)) | ||
| 5337 | for (short i = 0; i < DK8/2; ++i) { | ||
| 5338 | simdgroup_barrier(mem_flags::mem_none); | ||
| 5339 | |||
| 5340 | simdgroup_load(mq[0], pq + 0*8 + 16*i, DK); | ||
| 5341 | simdgroup_load(mq[1], pq + 1*8 + 16*i, DK); | ||
| 5342 | |||
| 5343 | simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true); | ||
| 5344 | simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true); | ||
| 5345 | |||
| 5346 | simdgroup_barrier(mem_flags::mem_none); | ||
| 5347 | |||
| 5348 | simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk); | ||
| 5349 | simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk); | ||
| 5350 | } | ||
| 5351 | } | ||
| 5352 | |||
| 5353 | simdgroup_store(mqk, ps, SH, 0, false); | ||
| 5354 | |||
| 5355 | pk += 8*(NSG*NS10); | ||
| 5356 | ps += 8*(NSG); | ||
| 5357 | } | ||
| 5358 | } else { | ||
| 5359 | // TODO: this is the quantized K cache branch - not optimized yet | ||
| 5360 | for (short ccc = 0; ccc < (C/8)/NSG; ++ccc) { | ||
| 5361 | const short cc = ccc*NSG + sgitg; | ||
| 5362 | |||
| 5363 | const short tx = tiisg%4; | ||
| 5364 | const short ty = tiisg/4; | ||
| 5365 | |||
| 5366 | qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f); | ||
| 5367 | |||
| 5368 | for (short ii = 0; ii < DK16; ii += 4) { | ||
| 5369 | device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11)); | ||
| 5370 | |||
| 5371 | if (DK16%4 == 0) { | ||
| 5372 | // the head is evenly divisible by 4*16 = 64, so no need for bound checks | ||
| 5373 | { | ||
| 5374 | k4x4_t tmp; | ||
| 5375 | deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); | ||
| 5376 | sk4x4[4*ty + tx] = tmp; | ||
| 5377 | } | ||
| 5378 | |||
| 5379 | simdgroup_barrier(mem_flags::mem_threadgroup); | ||
| 5380 | |||
| 5381 | FOR_UNROLL (short k = 0; k < 4; ++k) { | ||
| 5382 | k8x8_t mk; | ||
| 5383 | q8x8_t mq; | ||
| 5384 | |||
| 5385 | simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose | ||
| 5386 | simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); | ||
| 5387 | simdgroup_multiply_accumulate(mqk, mq, mk, mqk); | ||
| 5388 | |||
| 5389 | simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose | ||
| 5390 | simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); | ||
| 5391 | simdgroup_multiply_accumulate(mqk, mq, mk, mqk); | ||
| 5392 | } | ||
| 5393 | } else { | ||
| 5394 | if (ii + tx < DK16) { | ||
| 5395 | k4x4_t tmp; | ||
| 5396 | deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp); | ||
| 5397 | sk4x4[4*ty + tx] = tmp; | ||
| 5398 | } | ||
| 5399 | |||
| 5400 | simdgroup_barrier(mem_flags::mem_threadgroup); | ||
| 5401 | |||
| 5402 | for (short k = 0; k < 4 && ii + k < DK16; ++k) { | ||
| 5403 | k8x8_t mk; | ||
| 5404 | q8x8_t mq; | ||
| 5405 | |||
| 5406 | simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose | ||
| 5407 | simdgroup_load(mq, sq + (2*(ii + k) + 0)*8, DK); | ||
| 5408 | simdgroup_multiply_accumulate(mqk, mq, mk, mqk); | ||
| 5409 | |||
| 5410 | simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose | ||
| 5411 | simdgroup_load(mq, sq + (2*(ii + k) + 1)*8, DK); | ||
| 5412 | simdgroup_multiply_accumulate(mqk, mq, mk, mqk); | ||
| 5413 | } | ||
| 5414 | } | ||
| 5415 | } | ||
| 5416 | |||
| 5417 | simdgroup_store(mqk, ss + 8*cc, SH, 0, false); | ||
| 5418 | } | ||
| 5419 | } | ||
| 5420 | |||
| 5421 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 5422 | |||
| 5423 | // online softmax | ||
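| | // per query row j, given the new chunk of scores s: | ||
| | //   M_new = max(M_old, max(s)),  ms = exp(M_old - M_new) | ||
| | //   S     = S*ms + sum(exp(s - M_new)) | ||
| | //   P     = exp(s - M_new) is written back to ss, and the accumulator O is rescaled by ms | ||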
| 5424 | FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { | ||
| 5425 | const short j = jj*NSG + sgitg; | ||
| 5426 | |||
| 5427 | const float m = M[jj]; | ||
| 5428 | |||
| 5429 | // scale and apply the logitcap / mask | ||
| 5430 | float2 s2 = ss2[j*SH/2 + tiisg]*args.scale; | ||
| 5431 | |||
| 5432 | if (FC_flash_attn_ext_has_scap) { | ||
| 5433 | s2 = args.logit_softcap*precise::tanh(s2); | ||
| 5434 | } | ||
| 5435 | |||
| 5436 | // mqk = mqk + slope*mask | ||
| 5437 | if (blk_cur != 2) { | ||
| 5438 | if (FC_flash_attn_ext_has_bias) { | ||
| 5439 | s2 += s2_t(sm2[j*SH + tiisg])*slope; | ||
| 5440 | } else { | ||
| 5441 | s2 += s2_t(sm2[j*SH + tiisg]); | ||
| 5442 | } | ||
| 5443 | } | ||
| 5444 | |||
| 5445 | M[jj] = simd_max(max(M[jj], max(s2[0], s2[1]))); | ||
| 5446 | |||
| 5447 | const float ms = exp(m - M[jj]); | ||
| 5448 | const float2 vs2 = exp(s2 - M[jj]); | ||
| 5449 | |||
| 5450 | S[jj] = S[jj]*ms + simd_sum(vs2[0] + vs2[1]); | ||
| 5451 | |||
| 5452 | // the P matrix from the paper (Q rows, C columns) | ||
| 5453 | ss2[j*SH/2 + tiisg] = vs2; | ||
| 5454 | |||
| 5455 | if (DV4 % NW == 0) { | ||
| 5456 | FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { | ||
| 5457 | const short i = ii*NW + tiisg; | ||
| 5458 | |||
| 5459 | so4[j*PV4 + i] *= ms; | ||
| 5460 | } | ||
| 5461 | } else { | ||
| 5462 | for (short i = tiisg; i < DV4; i += NW) { | ||
| 5463 | so4[j*PV4 + i] *= ms; | ||
| 5464 | } | ||
| 5465 | } | ||
| 5466 | } | ||
| 5467 | |||
| 5468 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 5469 | |||
| 5470 | // O = O + (Q*K^T)*V | ||
| 5471 | { | ||
| 5472 | // we can read directly from global memory | ||
| 5473 | if (is_same<vd4x4_t, v4x4_t>::value) { | ||
| 5474 | static_assert(PV8 % NSG == 0, ""); | ||
| 5475 | |||
| 5476 | constexpr short NO = PV8/NSG; | ||
| 5477 | |||
| 5478 | o8x8_t lo[NO]; | ||
| 5479 | |||
| 5480 | { | ||
| 5481 | auto sot = so + 8*sgitg; | ||
| 5482 | |||
| 5483 | FOR_UNROLL (short ii = 0; ii < NO; ++ii) { | ||
| 5484 | simdgroup_load(lo[ii], sot, PV, 0, false); | ||
| 5485 | |||
| 5486 | sot += 8*NSG; | ||
| 5487 | } | ||
| 5488 | } | ||
| 5489 | |||
| 5490 | { | ||
| 5491 | device const v_t * pv = (device const v_t *) (v + ic*args.nb21); | ||
| 5492 | |||
| 5493 | pv += 8*sgitg; | ||
| 5494 | |||
| 5495 | if (DV <= 64) { | ||
| 5496 | FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { | ||
| 5497 | s8x8_t vs; | ||
| 5498 | simdgroup_load(vs, ss + 8*cc, SH, 0, false); | ||
| 5499 | |||
| 5500 | FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { | ||
| 5501 | v8x8_t mv[2]; | ||
| 5502 | |||
| 5503 | simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false); | ||
| 5504 | simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false); | ||
| 5505 | |||
| 5506 | simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]); | ||
| 5507 | simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]); | ||
| 5508 | } | ||
| 5509 | |||
| 5510 | pv += 8*NS20; | ||
| 5511 | } | ||
| 5512 | } else { | ||
| 5513 | constexpr short NC = (C/8)/2; | ||
| 5514 | |||
| 5515 | FOR_UNROLL (short cc = 0; cc < NC; ++cc) { | ||
| 5516 | s8x8_t vs[2]; | ||
| 5517 | |||
| 5518 | simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false); | ||
| 5519 | simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false); | ||
| 5520 | |||
| 5521 | FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { | ||
| 5522 | v8x8_t mv[4]; | ||
| 5523 | |||
| 5524 | simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); | ||
| 5525 | simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); | ||
| 5526 | simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); | ||
| 5527 | simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); | ||
| 5528 | |||
| 5529 | simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]); | ||
| 5530 | simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]); | ||
| 5531 | simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]); | ||
| 5532 | simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]); | ||
| 5533 | } | ||
| 5534 | |||
| 5535 | pv += 2*8*NS20; | ||
| 5536 | } | ||
| 5537 | } | ||
| 5538 | } | ||
| 5539 | |||
| 5540 | { | ||
| 5541 | auto sot = so + 8*sgitg; | ||
| 5542 | |||
| 5543 | FOR_UNROLL (short ii = 0; ii < NO; ++ii) { | ||
| 5544 | simdgroup_store(lo[ii], sot, PV, 0, false); | ||
| 5545 | |||
| 5546 | sot += 8*NSG; | ||
| 5547 | } | ||
| 5548 | } | ||
| 5549 | } else { | ||
| 5550 | // TODO: this is the quantized V cache branch - not optimized yet | ||
| 5551 | |||
| 5552 | const short tx = tiisg%4; | ||
| 5553 | const short ty = tiisg/4; | ||
| 5554 | |||
| 5555 | for (short cc = 0; cc < C/8; ++cc) { | ||
| 5556 | s8x8_t vs; | ||
| 5557 | simdgroup_load(vs, ss + 8*cc, SH, 0, false); | ||
| 5558 | |||
| 5559 | for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) { | ||
| 5560 | device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21)); | ||
| 5561 | |||
| 5562 | if (DV16%4 == 0) { | ||
| 5563 | // no need for bound checks | ||
| 5564 | { | ||
| 5565 | v4x4_t tmp; | ||
| 5566 | deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); | ||
| 5567 | sv4x4[4*ty + tx] = tmp; | ||
| 5568 | } | ||
| 5569 | |||
| 5570 | simdgroup_barrier(mem_flags::mem_threadgroup); | ||
| 5571 | |||
| 5572 | FOR_UNROLL (short k = 0; k < 4; ++k) { | ||
| 5573 | v8x8_t mv[2]; | ||
| 5574 | o8x8_t lo[2]; | ||
| 5575 | |||
| 5576 | simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false); | ||
| 5577 | simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false); | ||
| 5578 | simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); | ||
| 5579 | simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); | ||
| 5580 | |||
| 5581 | simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]); | ||
| 5582 | simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]); | ||
| 5583 | |||
| 5584 | simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); | ||
| 5585 | simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); | ||
| 5586 | } | ||
| 5587 | } else { | ||
| 5588 | if (ii + tx < DV16) { | ||
| 5589 | v4x4_t tmp; | ||
| 5590 | deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp); | ||
| 5591 | sv4x4[4*ty + tx] = tmp; | ||
| 5592 | } | ||
| 5593 | |||
| 5594 | simdgroup_barrier(mem_flags::mem_threadgroup); | ||
| 5595 | |||
| 5596 | for (short k = 0; k < 4 && ii + k < DV16; ++k) { | ||
| 5597 | v8x8_t mv[2]; | ||
| 5598 | o8x8_t lo[2]; | ||
| 5599 | |||
| 5600 | simdgroup_load(mv[0], sv + 16*k + 0*8, 4*16, 0, false); | ||
| 5601 | simdgroup_load(mv[1], sv + 16*k + 1*8, 4*16, 0, false); | ||
| 5602 | simdgroup_load(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); | ||
| 5603 | simdgroup_load(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); | ||
| 5604 | |||
| 5605 | simdgroup_multiply_accumulate(lo[0], vs, mv[0], lo[0]); | ||
| 5606 | simdgroup_multiply_accumulate(lo[1], vs, mv[1], lo[1]); | ||
| 5607 | |||
| 5608 | simdgroup_store(lo[0], so + 8*(2*(ii + k) + 0), PV, 0, false); | ||
| 5609 | simdgroup_store(lo[1], so + 8*(2*(ii + k) + 1), PV, 0, false); | ||
| 5610 | } | ||
| 5611 | } | ||
| 5612 | } | ||
| 5613 | } | ||
| 5614 | } | ||
| 5615 | } | ||
| 5616 | |||
| 5617 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 5618 | } | ||
| 5619 | |||
| 5620 | if (FC_flash_attn_ext_has_sinks) { | ||
| 5621 | FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { | ||
| 5622 | const short j = jj*NSG + sgitg; | ||
| 5623 | |||
| 5624 | const float m = M[jj]; | ||
| 5625 | const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; | ||
| 5626 | |||
| 5627 | M[jj] = simd_max(max(M[jj], s)); | ||
| 5628 | |||
| 5629 | const float ms = exp(m - M[jj]); | ||
| 5630 | const float vs = exp(s - M[jj]); | ||
| 5631 | |||
| 5632 | S[jj] = S[jj]*ms + simd_sum(vs); | ||
| 5633 | |||
| 5634 | for (short i = tiisg; i < DV4; i += NW) { | ||
| 5635 | so4[j*PV4 + i] *= ms; | ||
| 5636 | } | ||
| 5637 | } | ||
| 5638 | } | ||
| 5639 | } | ||
| 5640 | |||
| 5641 | // store to global memory | ||
| 5642 | for (short jj = 0; jj < NQ; ++jj) { | ||
| 5643 | const short j = jj*NSG + sgitg; | ||
| 5644 | if (iq1 + j >= args.ne01) { | ||
| 5645 | break; | ||
| 5646 | } | ||
| 5647 | |||
| 5648 | device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; | ||
| 5649 | |||
| 5650 | const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj]; | ||
| 5651 | |||
| 5652 | if (DV4 % NW == 0) { | ||
| 5653 | FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { | ||
| 5654 | const short i = ii*NW + tiisg; | ||
| 5655 | |||
| 5656 | dst4[i] = (float4) so4[j*PV4 + i]*scale; | ||
| 5657 | } | ||
| 5658 | } else { | ||
| 5659 | for (short i = tiisg; i < DV4; i += NW) { | ||
| 5660 | dst4[i] = (float4) so4[j*PV4 + i]*scale; | ||
| 5661 | } | ||
| 5662 | } | ||
| 5663 | } | ||
| 5664 | |||
| 5665 | #undef NS10 | ||
| 5666 | #undef NS20 | ||
| 5667 | } | ||
| 5668 | |||
| 5669 | template< | ||
| 5670 | typename q_t, // query types in shared memory | ||
| 5671 | typename q4_t, | ||
| 5672 | typename q8x8_t, | ||
| 5673 | typename k_t, // key types in shared memory | ||
| 5674 | typename k4x4_t, | ||
| 5675 | typename k8x8_t, | ||
| 5676 | typename v_t, // value types in shared memory | ||
| 5677 | typename v4x4_t, | ||
| 5678 | typename v8x8_t, | ||
| 5679 | typename qk_t, // Q*K types | ||
| 5680 | typename qk8x8_t, | ||
| 5681 | typename s_t, // soft-max types | ||
| 5682 | typename s2_t, | ||
| 5683 | typename s8x8_t, | ||
| 5684 | typename o_t, // attention accumulation types | ||
| 5685 | typename o4_t, | ||
| 5686 | typename o8x8_t, | ||
| 5687 | typename kd4x4_t, // key type in device memory | ||
| 5688 | short nl_k, | ||
| 5689 | void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &), | ||
| 5690 | typename vd4x4_t, // value type in device memory | ||
| 5691 | short nl_v, | ||
| 5692 | void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), | ||
| 5693 | short DK, // K head size | ||
| 5694 | short DV, // V head size | ||
| 5695 | short Q = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup | ||
| 5696 | short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup | ||
| 5697 | kernel void kernel_flash_attn_ext( | ||
| 5698 | constant ggml_metal_kargs_flash_attn_ext & args, | ||
| 5699 | device const char * q, | ||
| 5700 | device const char * k, | ||
| 5701 | device const char * v, | ||
| 5702 | device const char * mask, | ||
| 5703 | device const char * sinks, | ||
| 5704 | device const char * pad, | ||
| 5705 | device const char * blk, | ||
| 5706 | device char * dst, | ||
| 5707 | threadgroup half * shmem_f16 [[threadgroup(0)]], | ||
| 5708 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 5709 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 5710 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 5711 | #define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C | ||
| 5712 | #define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg | ||
| 5713 | switch (FC_flash_attn_ext_nsg) { | ||
| 5714 | // note: disabled cases to reduce library load time | ||
| 5715 | //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break; | ||
| 5716 | //case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break; | ||
| 5717 | case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break; | ||
| 5718 | case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break; | ||
| 5719 | } | ||
| 5720 | #undef FWD_TMPL | ||
| 5721 | #undef FWD_ARGS | ||
| 5722 | } | ||
| 5723 | |||
| 5724 | // TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as | ||
| 5725 | // templates to be able to explore different combinations | ||
| 5726 | // | ||
| 5727 | #define FA_TYPES \ | ||
| 5728 | half, half4, simdgroup_half8x8, \ | ||
| 5729 | half, half4x4, simdgroup_half8x8, \ | ||
| 5730 | half, half4x4, simdgroup_half8x8, \ | ||
| 5731 | float, simdgroup_float8x8, \ | ||
| 5732 | float, float2, simdgroup_float8x8, \ | ||
| 5733 | float, float4, simdgroup_float8x8 | ||
| 5734 | //half, half4, simdgroup_half8x8 | ||
| 5735 | |||
| 5736 | #define FA_TYPES_BF \ | ||
| 5737 | bfloat, bfloat4, simdgroup_bfloat8x8, \ | ||
| 5738 | bfloat, bfloat4x4, simdgroup_bfloat8x8, \ | ||
| 5739 | bfloat, bfloat4x4, simdgroup_bfloat8x8, \ | ||
| 5740 | float, simdgroup_float8x8, \ | ||
| 5741 | float, float2, simdgroup_float8x8, \ | ||
| 5742 | half, half4, simdgroup_half8x8 | ||
| 5743 | //float, float4, simdgroup_float8x8 | ||
| 5744 | |||
| 5745 | #define FA_TYPES_F32 \ | ||
| 5746 | half, half4, simdgroup_half8x8, \ | ||
| 5747 | float, float4x4, simdgroup_float8x8, \ | ||
| 5748 | float, float4x4, simdgroup_float8x8, \ | ||
| 5749 | float, simdgroup_float8x8, \ | ||
| 5750 | float, float2, simdgroup_float8x8, \ | ||
| 5751 | float, float4, simdgroup_float8x8 | ||
| 5752 | //half, half4, simdgroup_half8x8 | ||
| 5753 | |||
| 5754 | typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t; | ||
| 5755 | |||
| 5756 | template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>; | ||
| 5757 | template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>; | ||
| 5758 | template [[host_name("kernel_flash_attn_ext_f32_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 48, 48>; | ||
| 5759 | template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>; | ||
| 5760 | template [[host_name("kernel_flash_attn_ext_f32_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 72, 72>; | ||
| 5761 | template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>; | ||
| 5762 | template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>; | ||
| 5763 | template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>; | ||
| 5764 | template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 128, 128>; | ||
| 5765 | template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>; | ||
| 5766 | template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>; | ||
| 5767 | template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>; | ||
| 5768 | template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>; | ||
| 5769 | |||
| 5770 | template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>; | ||
| 5771 | template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>; | ||
| 5772 | template [[host_name("kernel_flash_attn_ext_f16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 48, 48>; | ||
| 5773 | template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>; | ||
| 5774 | template [[host_name("kernel_flash_attn_ext_f16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 72, 72>; | ||
| 5775 | template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>; | ||
| 5776 | template [[host_name("kernel_flash_attn_ext_f16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96, 96>; | ||
| 5777 | template [[host_name("kernel_flash_attn_ext_f16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112, 112>; | ||
| 5778 | template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128, 128>; | ||
| 5779 | template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>; | ||
| 5780 | template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>; | ||
| 5781 | template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>; | ||
| 5782 | template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>; | ||
| 5783 | |||
| 5784 | #if defined(GGML_METAL_HAS_BF16) | ||
| 5785 | template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>; | ||
| 5786 | template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>; | ||
| 5787 | template [[host_name("kernel_flash_attn_ext_bf16_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 48, 48>; | ||
| 5788 | template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>; | ||
| 5789 | template [[host_name("kernel_flash_attn_ext_bf16_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 72, 72>; | ||
| 5790 | template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>; | ||
| 5791 | template [[host_name("kernel_flash_attn_ext_bf16_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>; | ||
| 5792 | template [[host_name("kernel_flash_attn_ext_bf16_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>; | ||
| 5793 | template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>; | ||
| 5794 | template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>; | ||
| 5795 | template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>; | ||
| 5796 | template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>; | ||
| 5797 | template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>; | ||
| 5798 | #endif | ||
| 5799 | |||
| 5800 | template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>; | ||
| 5801 | template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>; | ||
| 5802 | template [[host_name("kernel_flash_attn_ext_q4_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 48, 48>; | ||
| 5803 | template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>; | ||
| 5804 | template [[host_name("kernel_flash_attn_ext_q4_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 72, 72>; | ||
| 5805 | template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>; | ||
| 5806 | template [[host_name("kernel_flash_attn_ext_q4_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96, 96>; | ||
| 5807 | template [[host_name("kernel_flash_attn_ext_q4_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112, 112>; | ||
| 5808 | template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128, 128>; | ||
| 5809 | template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>; | ||
| 5810 | template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>; | ||
| 5811 | template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>; | ||
| 5812 | template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>; | ||
| 5813 | |||
| 5814 | template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>; | ||
| 5815 | template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>; | ||
| 5816 | template [[host_name("kernel_flash_attn_ext_q4_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 48, 48>; | ||
| 5817 | template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>; | ||
| 5818 | template [[host_name("kernel_flash_attn_ext_q4_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 72, 72>; | ||
| 5819 | template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>; | ||
| 5820 | template [[host_name("kernel_flash_attn_ext_q4_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96, 96>; | ||
| 5821 | template [[host_name("kernel_flash_attn_ext_q4_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112, 112>; | ||
| 5822 | template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128, 128>; | ||
| 5823 | template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>; | ||
| 5824 | template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>; | ||
| 5825 | template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>; | ||
| 5826 | template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>; | ||
| 5827 | |||
| 5828 | template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>; | ||
| 5829 | template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>; | ||
| 5830 | template [[host_name("kernel_flash_attn_ext_q5_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 48, 48>; | ||
| 5831 | template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>; | ||
| 5832 | template [[host_name("kernel_flash_attn_ext_q5_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 72, 72>; | ||
| 5833 | template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>; | ||
| 5834 | template [[host_name("kernel_flash_attn_ext_q5_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96, 96>; | ||
| 5835 | template [[host_name("kernel_flash_attn_ext_q5_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112, 112>; | ||
| 5836 | template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128, 128>; | ||
| 5837 | template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>; | ||
| 5838 | template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>; | ||
| 5839 | template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>; | ||
| 5840 | template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>; | ||
| 5841 | |||
| 5842 | template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>; | ||
| 5843 | template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>; | ||
| 5844 | template [[host_name("kernel_flash_attn_ext_q5_1_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 48, 48>; | ||
| 5845 | template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>; | ||
| 5846 | template [[host_name("kernel_flash_attn_ext_q5_1_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 72, 72>; | ||
| 5847 | template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>; | ||
| 5848 | template [[host_name("kernel_flash_attn_ext_q5_1_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96, 96>; | ||
| 5849 | template [[host_name("kernel_flash_attn_ext_q5_1_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112, 112>; | ||
| 5850 | template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128, 128>; | ||
| 5851 | template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>; | ||
| 5852 | template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>; | ||
| 5853 | template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>; | ||
| 5854 | template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>; | ||
| 5855 | |||
| 5856 | template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>; | ||
| 5857 | template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>; | ||
| 5858 | template [[host_name("kernel_flash_attn_ext_q8_0_dk48_dv48" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 48, 48>; | ||
| 5859 | template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>; | ||
| 5860 | template [[host_name("kernel_flash_attn_ext_q8_0_dk72_dv72" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 72, 72>; | ||
| 5861 | template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>; | ||
| 5862 | template [[host_name("kernel_flash_attn_ext_q8_0_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96, 96>; | ||
| 5863 | template [[host_name("kernel_flash_attn_ext_q8_0_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112, 112>; | ||
| 5864 | template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128, 128>; | ||
| 5865 | template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>; | ||
| 5866 | template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>; | ||
| 5867 | template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>; | ||
| 5868 | template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>; | ||
| 5869 | |||
| 5870 | #undef FA_TYPES | ||
| 5871 | #undef FA_TYPES_BF | ||
| 5872 | #undef FA_TYPES_F32 | ||
| 5873 | |||
| 5874 | constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]]; | ||
| 5875 | constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]]; | ||
| 5876 | constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]]; | ||
| 5877 | constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]]; | ||
| 5878 | constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]]; | ||
| 5879 | |||
| 5880 | //constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]]; | ||
| 5881 | //constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]]; | ||
| 5882 | //constant float FC_flash_attn_ext_vec_logit_softcap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 12)]]; | ||
| 5883 | |||
| 5884 | constant int32_t FC_flash_attn_ext_vec_ns10 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 20)]]; | ||
| 5885 | constant int32_t FC_flash_attn_ext_vec_ns20 [[function_constant(FC_FLASH_ATTN_EXT_VEC + 21)]]; | ||
| 5886 | constant int32_t FC_flash_attn_ext_vec_nsg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 22)]]; | ||
| 5887 | constant int32_t FC_flash_attn_ext_vec_nwg [[function_constant(FC_FLASH_ATTN_EXT_VEC + 23)]]; | ||
| 5888 | |||
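| | // note: the FC_flash_attn_ext_vec_* values above are Metal function constants, presumably supplied | ||
| | //       by the host when the pipeline is specialized, so the has_* branches and the NSG/NWG/NS10/NS20 | ||
| | //       strides below are resolved at compile time instead of being passed as kernel arguments | ||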
| 5889 | template< | ||
| 5890 | typename q4_t, // query types in shared memory | ||
| 5891 | typename k4_t, // key types in shared memory | ||
| 5892 | typename v4_t, // value types in shared memory | ||
| 5893 | typename qk_t, // Q*K types | ||
| 5894 | typename s_t, // soft-max types | ||
| 5895 | typename s4_t, | ||
| 5896 | typename o4_t, // attention accumulation types | ||
| 5897 | typename kd4_t, // key type in device memory | ||
| 5898 | short nl_k, | ||
| 5899 | void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), | ||
| 5900 | typename vd4_t, // value type in device memory | ||
| 5901 | short nl_v, | ||
| 5902 | void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), | ||
| 5903 | short DK, // K head size | ||
| 5904 | short DV, // V head size | ||
| 5905 | short NE = 4, // head elements per thread | ||
| 5906 | short Q = OP_FLASH_ATTN_EXT_VEC_NQPSG, // queries per threadgroup | ||
| 5907 | short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup | ||
| 5908 | kernel void kernel_flash_attn_ext_vec( | ||
| 5909 | constant ggml_metal_kargs_flash_attn_ext_vec & args, | ||
| 5910 | device const char * q, | ||
| 5911 | device const char * k, | ||
| 5912 | device const char * v, | ||
| 5913 | device const char * mask, | ||
| 5914 | device const char * sinks, | ||
| 5915 | device const char * pad, | ||
| 5916 | device char * dst, | ||
| 5917 | threadgroup half * shmem_f16 [[threadgroup(0)]], | ||
| 5918 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 5919 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 5920 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 5921 | static_assert(DK % 32 == 0, "DK must be divisible by 32"); | ||
| 5922 | static_assert(DV % 32 == 0, "DV must be divisible by 32"); | ||
| 5923 | |||
| 5924 | #define NWG (FC_flash_attn_ext_vec_nwg) | ||
| 5925 | #define NSG (FC_flash_attn_ext_vec_nsg) | ||
| 5926 | |||
| 5927 | #define NS10 (FC_flash_attn_ext_vec_ns10) | ||
| 5928 | #define NS20 (FC_flash_attn_ext_vec_ns20) | ||
| 5929 | |||
| 5930 | const short iwg = tgpig[2]%NWG; | ||
| 5931 | |||
| 5932 | const ushort iq3 = tgpig[2]/NWG; | ||
| 5933 | const ushort iq2 = tgpig[1]; | ||
| 5934 | const ushort iq1 = tgpig[0]; | ||
| 5935 | |||
| 5936 | constexpr short DK4 = DK/4; | ||
| 5937 | constexpr short DV4 = DV/4; | ||
| 5938 | |||
| 5939 | constexpr short PK = PAD2(DK, 128); | ||
| 5940 | constexpr short PK4 = PK/4; | ||
| 5941 | |||
| 5942 | constexpr short PV = PAD2(DV, 128); | ||
| 5943 | constexpr short PV4 = PV/4; | ||
| 5944 | |||
| 5945 | constexpr short NW = N_SIMDWIDTH; | ||
| 5946 | constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup workloads | ||
| 5947 | constexpr short SH = 4*C; // shared memory per simdgroup | ||
| 5948 | |||
| 5949 | static_assert(DK4 % NL == 0, "DK4 must be divisible by NL"); | ||
| 5950 | static_assert(DV4 % NL == 0, "DV4 must be divisible by NL"); | ||
| 5951 | |||
| 5952 | const short T = PK + NSG*SH; // shared memory size per query in (half) | ||
| 5953 | |||
| 5954 | //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data | ||
| 5955 | threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t | ||
| 5956 | threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + NSG*PK); // scratch buffer for attention | ||
| 5957 | threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + NSG*PK); // same as above but in s4_t | ||
| 5958 | threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask | ||
| 5959 | threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + NSG*PK + NSG*SH); // scratch buffer for the results | ||
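| | // rough threadgroup memory layout (in halfs), as implied by the pointers above: | ||
| | //   [0, NSG*PK)                       - padded query data (sq4) | ||
| | //   [NSG*PK, NSG*PK + NSG*SH)         - per-simdgroup score scratch (ss/ss4) and mask (sm) | ||
| | //   [NSG*PK + NSG*SH, ... + 2*NSG*PV) - per-simdgroup output accumulators (so4) | ||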
| 5960 | |||
| 5961 | // store the result for all queries in shared memory (the O matrix from the paper) | ||
| 5962 | so4 += tiisg; | ||
| 5963 | |||
| 5964 | { | ||
| 5965 | q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; | ||
| 5966 | |||
| 5967 | const short ikv2 = iq2/(args.ne02/args.ne_12_2); | ||
| 5968 | const short ikv3 = iq3/(args.ne03/args.ne_12_3); | ||
| 5969 | |||
| 5970 | k += ikv2*args.nb12 + ikv3*args.nb13; | ||
| 5971 | v += ikv2*args.nb22 + ikv3*args.nb23; | ||
| 5972 | } | ||
| 5973 | |||
| 5974 | // load heads from Q to shared memory | ||
| 5975 | device const float4 * q4 = (device const float4 *) ((device const char *) q); | ||
| 5976 | |||
| 5977 | if (iq1 < args.ne01) { | ||
| 5978 | for (short i = tiisg; i < PK4; i += NW) { | ||
| 5979 | if (i < DK4) { | ||
| 5980 | sq4[i] = (q4_t) q4[i]; | ||
| 5981 | } else { | ||
| 5982 | sq4[i] = (q4_t) 0.0f; | ||
| 5983 | } | ||
| 5984 | } | ||
| 5985 | } | ||
| 5986 | |||
| 5987 | // zero out so | ||
| 5988 | for (short i = 0; i < DV4/NL; ++i) { | ||
| 5989 | so4[i*NL] = (o4_t) 0.0f; | ||
| 5990 | } | ||
| 5991 | |||
| 5992 | // zero out shared memory SH | ||
| 5993 | for (short i = tiisg; i < SH/4; i += NW) { | ||
| 5994 | ss4[i] = (s4_t) 0.0f; | ||
| 5995 | } | ||
| 5996 | |||
| 5997 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 5998 | |||
| 5999 | { | ||
| 6000 | float S = 0.0f; | ||
| 6001 | float M = -FLT_MAX/2; | ||
| 6002 | |||
| 6003 | // thread indices inside the simdgroup | ||
| 6004 | const short tx = tiisg%NL; | ||
| 6005 | const short ty = tiisg/NL; | ||
| 6006 | |||
| 6007 | // pointer to the mask | ||
| 6008 | device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); | ||
| 6009 | |||
| 6010 | float slope = 1.0f; | ||
| 6011 | |||
| 6012 | // ALiBi | ||
| 6013 | if (FC_flash_attn_ext_vec_has_bias) { | ||
| 6014 | const short h = iq2; | ||
| 6015 | |||
| 6016 | const float base = h < args.n_head_log2 ? args.m0 : args.m1; | ||
| 6017 | const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; | ||
| 6018 | |||
| 6019 | slope = pow(base, exph); | ||
| 6020 | } | ||
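| | // i.e. the usual ALiBi slopes: slope = m0^(h+1) for the first n_head_log2 heads and | ||
| | // m1^(2*(h - n_head_log2) + 1) for the rest; e.g. m0 = 0.5, h = 2 gives slope = 0.5^3 = 0.125 | ||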
| 6021 | |||
| 6022 | // loop over the KV cache | ||
| 6023 | // each simdgroup handles blocks of Q rows and C columns | ||
| 6024 | for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) { | ||
| 6025 | int ic = ic0*C; | ||
| 6026 | if (ic >= args.ne11) { | ||
| 6027 | break; | ||
| 6028 | } | ||
| 6029 | |||
| 6030 | // the last partial chunk uses the pad buffer as source | ||
| 6031 | if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) { | ||
| 6032 | k = pad; | ||
| 6033 | v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; | ||
| 6034 | mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; | ||
| 6035 | |||
| 6036 | const short ikv2 = iq2/(args.ne02/args.ne_12_2); | ||
| 6037 | const short ikv3 = iq3/(args.ne03/args.ne_12_3); | ||
| 6038 | |||
| 6039 | k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; | ||
| 6040 | v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; | ||
| 6041 | |||
| 6042 | if (!FC_flash_attn_ext_vec_has_mask) { | ||
| 6043 | if (ic + tiisg >= args.ne11) { | ||
| 6044 | sm[tiisg] = -MAXHALF; | ||
| 6045 | } | ||
| 6046 | } else { | ||
| 6047 | pm = (device const half *) (mask) + | ||
| 6048 | iq1*C + | ||
| 6049 | (iq2%args.ne32)*(C*args.ne31) + | ||
| 6050 | (iq3%args.ne33)*(C*args.ne31*args.ne32); | ||
| 6051 | } | ||
| 6052 | |||
| 6053 | ic = 0; | ||
| 6054 | } | ||
| 6055 | |||
| 6056 | if (FC_flash_attn_ext_vec_has_mask) { | ||
| 6057 | sm[tiisg] = pm[ic + tiisg]; | ||
| 6058 | } | ||
| 6059 | |||
| 6060 | // skip -INF blocks | ||
| 6061 | if (simd_max(sm[tiisg]) <= -MAXHALF) { | ||
| 6062 | continue; | ||
| 6063 | } | ||
| 6064 | |||
| 6065 | // Q*K^T | ||
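| | //     work split: the NE rows of the simdgroup (ty) each handle one of NE consecutive cache positions, | ||
| | //     while the NL lanes of a row (tx) stride over the head dimension in float4 chunks; the shuffle | ||
| | //     tree below then sums the per-lane partial dot products of each row into lane NL*ty and broadcasts | ||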
| 6066 | { | ||
| 6067 | device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11); | ||
| 6068 | threadgroup const q4_t * pq4 = sq4; | ||
| 6069 | |||
| 6070 | pk4 += ty*NS10/4 + tx; | ||
| 6071 | pq4 += tx; | ||
| 6072 | |||
| 6073 | qk_t mqk[C/NE] = { [ 0 ... C/NE - 1] = 0.0f }; | ||
| 6074 | |||
| 6075 | // each simdgroup processes 1 query and NE (NW/NL) cache elements | ||
| 6076 | FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { | ||
| 6077 | if (is_same<kd4_t, k4_t>::value) { | ||
| 6078 | FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { | ||
| 6079 | mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]); | ||
| 6080 | } | ||
| 6081 | } else { | ||
| 6082 | device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11)); | ||
| 6083 | |||
| 6084 | k4_t mk; | ||
| 6085 | |||
| 6086 | FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { | ||
| 6087 | const short i = ii*NL + tx; | ||
| 6088 | |||
| 6089 | deq_k_t4(pk + i/nl_k, i%nl_k, mk); | ||
| 6090 | |||
| 6091 | mqk[cc] += dot((float4) mk, (float4) sq4[i]); | ||
| 6092 | } | ||
| 6093 | } | ||
| 6094 | |||
| 6095 | if (NE == 1) { | ||
| 6096 | mqk[cc] = simd_sum(mqk[cc]); | ||
| 6097 | } else { | ||
| 6098 | // simdgroup reduce (NE = 4) | ||
| 6099 | // [ 0 .. 7] -> [ 0] | ||
| 6100 | // [ 8 .. 15] -> [ 8] | ||
| 6101 | // [16 .. 23] -> [16] | ||
| 6102 | // [24 .. 31] -> [24] | ||
| 6103 | if (NE <= 1) { | ||
| 6104 | mqk[cc] += simd_shuffle_down(mqk[cc], 16); | ||
| 6105 | } | ||
| 6106 | if (NE <= 2) { | ||
| 6107 | mqk[cc] += simd_shuffle_down(mqk[cc], 8); | ||
| 6108 | } | ||
| 6109 | if (NE <= 4) { | ||
| 6110 | mqk[cc] += simd_shuffle_down(mqk[cc], 4); | ||
| 6111 | } | ||
| 6112 | if (NE <= 8) { | ||
| 6113 | mqk[cc] += simd_shuffle_down(mqk[cc], 2); | ||
| 6114 | } | ||
| 6115 | if (NE <= 16) { | ||
| 6116 | mqk[cc] += simd_shuffle_down(mqk[cc], 1); | ||
| 6117 | } | ||
| 6118 | |||
| 6119 | // broadcast | ||
| 6120 | mqk[cc] = simd_shuffle(mqk[cc], NL*ty); | ||
| 6121 | } | ||
| 6122 | } | ||
| 6123 | |||
| 6124 | if (FC_flash_attn_ext_vec_has_mask && | ||
| 6125 | !FC_flash_attn_ext_vec_has_scap && | ||
| 6126 | !FC_flash_attn_ext_vec_has_bias) { | ||
| 6127 | ss[NE*tx + ty] = fma(mqk[tx], args.scale, (qk_t) sm[NE*tx + ty]); | ||
| 6128 | } else { | ||
| 6129 | mqk[tx] *= args.scale; | ||
| 6130 | |||
| 6131 | if (FC_flash_attn_ext_vec_has_scap) { | ||
| 6132 | mqk[tx] = args.logit_softcap*precise::tanh(mqk[tx]); | ||
| 6133 | } | ||
| 6134 | |||
| 6135 | if (FC_flash_attn_ext_vec_has_bias) { | ||
| 6136 | mqk[tx] += (qk_t) sm[NE*tx + ty]*slope; | ||
| 6137 | } else { | ||
| 6138 | mqk[tx] += (qk_t) sm[NE*tx + ty]; | ||
| 6139 | } | ||
| 6140 | |||
| 6141 | ss[NE*tx + ty] = mqk[tx]; | ||
| 6142 | } | ||
| 6143 | } | ||
| 6144 | |||
| 6145 | simdgroup_barrier(mem_flags::mem_threadgroup); | ||
| 6146 | |||
| 6147 | // online softmax | ||
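| | //     streaming softmax recurrence: with block scores s_j and running state (S, M), | ||
| | //       M' = max(M, max_j s_j),  S' = S*exp(M - M') + sum_j exp(s_j - M'),  O' = O*exp(M - M') + P*V | ||
| | //     which gives the same result as one softmax over the full KV range | ||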
| 6148 | { | ||
| 6149 | const float m = M; | ||
| 6150 | const float s = ss[tiisg]; | ||
| 6151 | |||
| 6152 | M = simd_max(max(M, s)); | ||
| 6153 | |||
| 6154 | const float ms = exp(m - M); | ||
| 6155 | const float vs = exp(s - M); | ||
| 6156 | |||
| 6157 | S = S*ms + simd_sum(vs); | ||
| 6158 | |||
| 6159 | // the P matrix from the paper (Q rows, C columns) | ||
| 6160 | ss[tiisg] = vs; | ||
| 6161 | |||
| 6162 | // O = diag(ms)*O | ||
| 6163 | if ((DV4/NL % NW == 0) || ty == 0) { | ||
| 6164 | FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { | ||
| 6165 | so4[ii*NL] *= ms; | ||
| 6166 | } | ||
| 6167 | } | ||
| 6168 | } | ||
| 6169 | |||
| 6170 | simdgroup_barrier(mem_flags::mem_threadgroup); | ||
| 6171 | |||
| 6172 | // O = O + (Q*K^T)*V | ||
| 6173 | { | ||
| 6174 | o4_t lo[DV4/NL]; | ||
| 6175 | FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { | ||
| 6176 | lo[ii] = 0.0f; | ||
| 6177 | } | ||
| 6178 | |||
| 6179 | if (is_same<vd4_t, v4_t>::value) { | ||
| 6180 | device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21); | ||
| 6181 | |||
| 6182 | pv4 += ty*NS20/4 + tx; | ||
| 6183 | |||
| 6184 | const auto sst = ss + ty; | ||
| 6185 | |||
| 6186 | FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { | ||
| 6187 | FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { | ||
| 6188 | lo[ii] += o4_t(float4(pv4[cc*NE*NS20/4 + ii*NL])*float4(sst[cc*NE])); | ||
| 6189 | } | ||
| 6190 | } | ||
| 6191 | } else { | ||
| 6192 | FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { | ||
| 6193 | device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21)); | ||
| 6194 | |||
| 6195 | FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { | ||
| 6196 | const short i = ii*NL + tx; | ||
| 6197 | |||
| 6198 | v4_t mv; | ||
| 6199 | deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); | ||
| 6200 | |||
| 6201 | lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty])); | ||
| 6202 | } | ||
| 6203 | } | ||
| 6204 | } | ||
| 6205 | |||
| 6206 | FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { | ||
| 6207 | if (NE > 1) { | ||
| 6208 | lo[ii][0] += simd_shuffle_down(lo[ii][0], 16); | ||
| 6209 | lo[ii][1] += simd_shuffle_down(lo[ii][1], 16); | ||
| 6210 | lo[ii][2] += simd_shuffle_down(lo[ii][2], 16); | ||
| 6211 | lo[ii][3] += simd_shuffle_down(lo[ii][3], 16); | ||
| 6212 | } | ||
| 6213 | |||
| 6214 | if (NE > 2) { | ||
| 6215 | lo[ii][0] += simd_shuffle_down(lo[ii][0], 8); | ||
| 6216 | lo[ii][1] += simd_shuffle_down(lo[ii][1], 8); | ||
| 6217 | lo[ii][2] += simd_shuffle_down(lo[ii][2], 8); | ||
| 6218 | lo[ii][3] += simd_shuffle_down(lo[ii][3], 8); | ||
| 6219 | } | ||
| 6220 | |||
| 6221 | if (NE > 4) { | ||
| 6222 | lo[ii][0] += simd_shuffle_down(lo[ii][0], 4); | ||
| 6223 | lo[ii][1] += simd_shuffle_down(lo[ii][1], 4); | ||
| 6224 | lo[ii][2] += simd_shuffle_down(lo[ii][2], 4); | ||
| 6225 | lo[ii][3] += simd_shuffle_down(lo[ii][3], 4); | ||
| 6226 | } | ||
| 6227 | |||
| 6228 | if (NE > 8) { | ||
| 6229 | lo[ii][0] += simd_shuffle_down(lo[ii][0], 2); | ||
| 6230 | lo[ii][1] += simd_shuffle_down(lo[ii][1], 2); | ||
| 6231 | lo[ii][2] += simd_shuffle_down(lo[ii][2], 2); | ||
| 6232 | lo[ii][3] += simd_shuffle_down(lo[ii][3], 2); | ||
| 6233 | } | ||
| 6234 | |||
| 6235 | if (NE > 16) { | ||
| 6236 | lo[ii][0] += simd_shuffle_down(lo[ii][0], 1); | ||
| 6237 | lo[ii][1] += simd_shuffle_down(lo[ii][1], 1); | ||
| 6238 | lo[ii][2] += simd_shuffle_down(lo[ii][2], 1); | ||
| 6239 | lo[ii][3] += simd_shuffle_down(lo[ii][3], 1); | ||
| 6240 | } | ||
| 6241 | } | ||
| 6242 | |||
| 6243 | if ((DV4/NL % NW == 0) || ty == 0) { | ||
| 6244 | FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { | ||
| 6245 | so4[ii*NL] += lo[ii]; | ||
| 6246 | } | ||
| 6247 | } | ||
| 6248 | } | ||
| 6249 | } | ||
| 6250 | |||
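| | // attention sinks: one extra per-head logit (read from the sinks buffer) joins the softmax | ||
| | // normalization below; it updates S and M and rescales the accumulated output, but contributes | ||
| | // no V term of its own | ||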
| 6251 | if (FC_flash_attn_ext_vec_has_sinks && sgitg == 0 && iwg == 0) { | ||
| 6252 | const float m = M; | ||
| 6253 | const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; | ||
| 6254 | |||
| 6255 | M = simd_max(max(M, s)); | ||
| 6256 | |||
| 6257 | const float ms = exp(m - M); | ||
| 6258 | const float vs = exp(s - M); | ||
| 6259 | |||
| 6260 | S = S*ms + simd_sum(vs); | ||
| 6261 | |||
| 6262 | if ((DV4/NL % NW == 0) || ty == 0) { | ||
| 6263 | FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { | ||
| 6264 | so4[ii*NL] *= ms; | ||
| 6265 | } | ||
| 6266 | } | ||
| 6267 | } | ||
| 6268 | |||
| 6269 | // these are needed for reducing the results from the simdgroups (reuse the ss buffer) | ||
| 6270 | if (tiisg == 0) { | ||
| 6271 | ss[0] = (s_t) S; | ||
| 6272 | ss[1] = (s_t) M; | ||
| 6273 | } | ||
| 6274 | } | ||
| 6275 | |||
| 6276 | so4 -= tiisg; | ||
| 6277 | |||
| 6278 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 6279 | |||
| 6280 | // parallel reduce | ||
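| | // log2(NSG) steps; each step merges two partial (S, M, O) results with the same rescaling | ||
| | // identity used by the streaming softmax above | ||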
| 6281 | for (short r = NSG/2; r > 0; r >>= 1) { | ||
| 6282 | if (sgitg < r) { | ||
| 6283 | const float S0 = ss[ 0]; | ||
| 6284 | const float S1 = ss[r*(SH/2) + 0]; | ||
| 6285 | |||
| 6286 | const float M0 = ss[ 1]; | ||
| 6287 | const float M1 = ss[r*(SH/2) + 1]; | ||
| 6288 | |||
| 6289 | const float M = max(M0, M1); | ||
| 6290 | |||
| 6291 | const float ms0 = exp(M0 - M); | ||
| 6292 | const float ms1 = exp(M1 - M); | ||
| 6293 | |||
| 6294 | const float S = S0*ms0 + S1*ms1; | ||
| 6295 | |||
| 6296 | if (tiisg == 0) { | ||
| 6297 | ss[0] = S; | ||
| 6298 | ss[1] = M; | ||
| 6299 | } | ||
| 6300 | |||
| 6301 | // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 | ||
| 6302 | for (short i = tiisg; i < DV4; i += NW) { | ||
| 6303 | so4[i] = so4[i]*ms0 + so4[i + r*PV4]*ms1; | ||
| 6304 | } | ||
| 6305 | } | ||
| 6306 | |||
| 6307 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 6308 | } | ||
| 6309 | |||
| 6310 | // final rescale with 1/S and store to global memory | ||
| 6311 | if (sgitg == 0) { | ||
| 6312 | const int64_t nrows = args.ne3*args.ne2*args.ne1; | ||
| 6313 | const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1; | ||
| 6314 | |||
| 6315 | device float4 * dst4 = (device float4 *) dst; | ||
| 6316 | device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results | ||
| 6317 | |||
| 6318 | const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f; | ||
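| | // note: with NWG > 1 the 1/S normalization is deferred - each workgroup writes its partial O | ||
| | //       together with (S, M), and kernel_flash_attn_ext_vec_reduce below combines them | ||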
| 6319 | |||
| 6320 | // interleave the workgroup data | ||
| 6321 | for (short i = tiisg; i < DV4; i += NW) { | ||
| 6322 | dst4[rid*DV4*NWG + NWG*i + iwg] = (float4) so4[i]*S; | ||
| 6323 | } | ||
| 6324 | |||
| 6325 | // store S and M | ||
| 6326 | if (NWG > 1) { | ||
| 6327 | if (tiisg == 0) { | ||
| 6328 | dst1[rid*(2*NWG) + 2*iwg + 0] = ss[0]; | ||
| 6329 | dst1[rid*(2*NWG) + 2*iwg + 1] = ss[1]; | ||
| 6330 | } | ||
| 6331 | } | ||
| 6332 | } | ||
| 6333 | |||
| 6334 | #undef NWG | ||
| 6335 | #undef NSG | ||
| 6336 | #undef NS10 | ||
| 6337 | #undef NS20 | ||
| 6338 | } | ||
| 6339 | |||
| 6340 | // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem | ||
| 6341 | // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max | ||
| 6342 | // | ||
| 6343 | #define FA_TYPES \ | ||
| 6344 | half4, \ | ||
| 6345 | half4, \ | ||
| 6346 | half4, \ | ||
| 6347 | float, \ | ||
| 6348 | float, float4, \ | ||
| 6349 | float4 | ||
| 6350 | |||
| 6351 | #define FA_TYPES_F32 \ | ||
| 6352 | half4, \ | ||
| 6353 | float4, \ | ||
| 6354 | float4, \ | ||
| 6355 | float, \ | ||
| 6356 | float, float4, \ | ||
| 6357 | float4 | ||
| 6358 | |||
| 6359 | typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t; | ||
| 6360 | |||
| 6361 | template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 32, 32, 4>; | ||
| 6362 | template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 32, 32, 4>; | ||
| 6363 | #if defined(GGML_METAL_HAS_BF16) | ||
| 6364 | template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 32, 32, 4>; | ||
| 6365 | #endif | ||
| 6366 | template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 32, 32, 4>; | ||
| 6367 | template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 32, 32, 4>; | ||
| 6368 | template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 32, 32, 4>; | ||
| 6369 | template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 32, 32, 4>; | ||
| 6370 | template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 32, 32, 4>; | ||
| 6371 | |||
| 6372 | template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 64, 64, 2>; | ||
| 6373 | template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>; | ||
| 6374 | #if defined(GGML_METAL_HAS_BF16) | ||
| 6375 | template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>; | ||
| 6376 | #endif | ||
| 6377 | template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 2>; | ||
| 6378 | template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>; | ||
| 6379 | template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>; | ||
| 6380 | template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>; | ||
| 6381 | template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>; | ||
| 6382 | |||
| 6383 | template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 96, 96, 4>; | ||
| 6384 | template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>; | ||
| 6385 | #if defined(GGML_METAL_HAS_BF16) | ||
| 6386 | template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>; | ||
| 6387 | #endif | ||
| 6388 | template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>; | ||
| 6389 | template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>; | ||
| 6390 | template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>; | ||
| 6391 | template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>; | ||
| 6392 | template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>; | ||
| 6393 | |||
| 6394 | template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 128, 128, 1>; | ||
| 6395 | template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>; | ||
| 6396 | #if defined(GGML_METAL_HAS_BF16) | ||
| 6397 | template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>; | ||
| 6398 | #endif | ||
| 6399 | template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>; | ||
| 6400 | template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>; | ||
| 6401 | template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>; | ||
| 6402 | template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>; | ||
| 6403 | template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>; | ||
| 6404 | |||
| 6405 | template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 192, 2>; | ||
| 6406 | template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>; | ||
| 6407 | #if defined(GGML_METAL_HAS_BF16) | ||
| 6408 | template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>; | ||
| 6409 | #endif | ||
| 6410 | template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>; | ||
| 6411 | template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>; | ||
| 6412 | template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>; | ||
| 6413 | template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>; | ||
| 6414 | template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>; | ||
| 6415 | |||
| 6416 | template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 128, 2>; | ||
| 6417 | template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>; | ||
| 6418 | #if defined(GGML_METAL_HAS_BF16) | ||
| 6419 | template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>; | ||
| 6420 | #endif | ||
| 6421 | template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>; | ||
| 6422 | template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>; | ||
| 6423 | template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>; | ||
| 6424 | template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>; | ||
| 6425 | template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>; | ||
| 6426 | |||
| 6427 | template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 256, 256, 1>; | ||
| 6428 | template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>; | ||
| 6429 | #if defined(GGML_METAL_HAS_BF16) | ||
| 6430 | template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>; | ||
| 6431 | #endif | ||
| 6432 | template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>; | ||
| 6433 | template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>; | ||
| 6434 | template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>; | ||
| 6435 | template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>; | ||
| 6436 | template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>; | ||
| 6437 | |||
| 6438 | template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>; | ||
| 6439 | template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>; | ||
| 6440 | #if defined(GGML_METAL_HAS_BF16) | ||
| 6441 | template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>; | ||
| 6442 | #endif | ||
| 6443 | template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>; | ||
| 6444 | template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>; | ||
| 6445 | template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>; | ||
| 6446 | template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>; | ||
| 6447 | template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>; | ||
| 6448 | |||
| 6449 | #undef FA_TYPES | ||
| 6450 | #undef FA_TYPES_F32 | ||
| 6451 | |||
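| | // reduce step for the multi-workgroup (NWG > 1) path: htmp holds the NWG interleaved partial | ||
| | // outputs of every row, followed by NWG (S, M) pairs per row; the merged result is | ||
| | //   m = max_w M_w,  S = sum_w S_w*exp(M_w - m),  O = (sum_w O_w*exp(M_w - m)) / S | ||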
| 6452 | constant int32_t FC_flash_attn_ext_vec_reduce_DV [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]]; | ||
| 6453 | constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]]; | ||
| 6454 | |||
| 6455 | kernel void kernel_flash_attn_ext_vec_reduce( | ||
| 6456 | constant ggml_metal_kargs_flash_attn_ext_vec_reduce & args, | ||
| 6457 | device const char * htmp, | ||
| 6458 | device char * dst, | ||
| 6459 | uint tgpig[[threadgroup_position_in_grid]], | ||
| 6460 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 6461 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 6462 | #define NWG (FC_flash_attn_ext_vec_reduce_NWG) | ||
| 6463 | #define DV (FC_flash_attn_ext_vec_reduce_DV) | ||
| 6464 | |||
| 6465 | const uint64_t rid = tgpig; | ||
| 6466 | |||
| 6467 | const short iwg = tiisg; | ||
| 6468 | |||
| 6469 | device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*NWG; | ||
| 6470 | |||
| 6471 | float S = ss[rid*(2*NWG) + 2*iwg + 0]; | ||
| 6472 | float M = ss[rid*(2*NWG) + 2*iwg + 1]; | ||
| 6473 | |||
| 6474 | const float m = simd_max(M); | ||
| 6475 | const float ms = exp(M - m); | ||
| 6476 | |||
| 6477 | S = simd_sum(S*ms); | ||
| 6478 | S = S == 0.0f ? 0.0f : 1.0f/S; | ||
| 6479 | |||
| 6480 | const short DV4 = DV/4; | ||
| 6481 | |||
| 6482 | device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*NWG; | ||
| 6483 | device float4 * dst4 = (device float4 *) dst + rid*DV4; | ||
| 6484 | |||
| 6485 | for (short i = sgitg; i < DV4; i += NWG) { | ||
| 6486 | const float4 v = simd_sum(htmp4[i*NWG + iwg]*ms); | ||
| 6487 | |||
| 6488 | if (iwg == 0) { | ||
| 6489 | dst4[i] = v*S; | ||
| 6490 | } | ||
| 6491 | } | ||
| 6492 | |||
| 6493 | #undef NWG | ||
| 6494 | #undef DV | ||
| 6495 | } | ||
| 6496 | |||
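| | // generic type-converting copy: the source position (i01, i02, i03) is flattened into a linear | ||
| | // element index n and re-split against the destination shape, so the same kernel also handles | ||
| | // reshaping copies between tensors with different ne*/nb* layouts | ||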
| 6497 | template<typename T0, typename T1> | ||
| 6498 | kernel void kernel_cpy_t_t( | ||
| 6499 | constant ggml_metal_kargs_cpy & args, | ||
| 6500 | device const char * src0, | ||
| 6501 | device char * dst, | ||
| 6502 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 6503 | ushort tiitg[[thread_index_in_threadgroup]], | ||
| 6504 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 6505 | const int i03 = tgpig[2]; | ||
| 6506 | const int i02 = tgpig[1]; | ||
| 6507 | const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; | ||
| 6508 | const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; | ||
| 6509 | |||
| 6510 | const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; | ||
| 6511 | |||
| 6512 | const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); | ||
| 6513 | const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); | ||
| 6514 | const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; | ||
| 6515 | const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); | ||
| 6516 | |||
| 6517 | device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); | ||
| 6518 | |||
| 6519 | for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) { | ||
| 6520 | device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); | ||
| 6521 | dst_data[i00] = (T1) src[0]; | ||
| 6522 | break; | ||
| 6523 | } | ||
| 6524 | } | ||
| 6525 | |||
| 6526 | typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t; | ||
| 6527 | |||
| 6528 | template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, float>; | ||
| 6529 | template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>; | ||
| 6530 | template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>; | ||
| 6531 | template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>; | ||
| 6532 | template [[host_name("kernel_cpy_i32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, int32_t>; | ||
| 6533 | #if defined(GGML_METAL_HAS_BF16) | ||
| 6534 | template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>; | ||
| 6535 | #endif | ||
| 6536 | template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>; | ||
| 6537 | template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<half, half>; | ||
| 6538 | #if defined(GGML_METAL_HAS_BF16) | ||
| 6539 | template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>; | ||
| 6540 | template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>; | ||
| 6541 | #endif | ||
| 6542 | |||
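| | // quantizing copy: each thread packs one block of QK consecutive f32 values into one block_q via | ||
| | // quantize_func; the destination index i0 is therefore in block units (the /QK below) and the | ||
| | // loop bound args.nk0 apparently counts blocks per row rather than scalar elements | ||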
| 6543 | template<short QK, | ||
| 6544 | typename block_q, | ||
| 6545 | void (*quantize_func)(device const float *, device block_q &)> | ||
| 6546 | kernel void kernel_cpy_f32_q( | ||
| 6547 | constant ggml_metal_kargs_cpy & args, | ||
| 6548 | device const char * src0, | ||
| 6549 | device char * dst, | ||
| 6550 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 6551 | ushort tiitg[[thread_index_in_threadgroup]], | ||
| 6552 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 6553 | const int i03 = tgpig[2]; | ||
| 6554 | const int i02 = tgpig[1]; | ||
| 6555 | const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; | ||
| 6556 | const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; | ||
| 6557 | |||
| 6558 | const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; | ||
| 6559 | |||
| 6560 | const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); | ||
| 6561 | const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); | ||
| 6562 | const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; | ||
| 6563 | const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; | ||
| 6564 | |||
| 6565 | device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); | ||
| 6566 | |||
| 6567 | for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { | ||
| 6568 | device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00); | ||
| 6569 | |||
| 6570 | quantize_func(src, dst_data[i00]); | ||
| 6571 | |||
| 6572 | break; | ||
| 6573 | } | ||
| 6574 | } | ||
| 6575 | |||
| 6576 | typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t; | ||
| 6577 | |||
| 6578 | template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>; | ||
| 6579 | template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>; | ||
| 6580 | template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>; | ||
| 6581 | template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>; | ||
| 6582 | template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1, block_q5_1, quantize_q5_1>; | ||
| 6583 | template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>; | ||
| 6584 | |||
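| | // dequantizing copy: each thread expands one group of 16 quantized values into a T4x4 tile via | ||
| | // dequantize_func; nl is the number of such tiles per block, so i00/nl selects the block and | ||
| | // i00%nl the tile within it | ||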
| 6585 | template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)> | ||
| 6586 | kernel void kernel_cpy_q_f32( | ||
| 6587 | constant ggml_metal_kargs_cpy & args, | ||
| 6588 | device const char * src0, | ||
| 6589 | device char * dst, | ||
| 6590 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 6591 | ushort tiitg[[thread_index_in_threadgroup]], | ||
| 6592 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 6593 | const int i03 = tgpig[2]; | ||
| 6594 | const int i02 = tgpig[1]; | ||
| 6595 | const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; | ||
| 6596 | const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; | ||
| 6597 | |||
| 6598 | const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; | ||
| 6599 | |||
| 6600 | const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); | ||
| 6601 | const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); | ||
| 6602 | const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; | ||
| 6603 | const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); | ||
| 6604 | |||
| 6605 | device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); | ||
| 6606 | device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); | ||
| 6607 | |||
| 6608 | for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { | ||
| 6609 | T4x4 temp; | ||
| 6610 | dequantize_func(src_data + i00/nl, i00%nl, temp); | ||
| 6611 | dst_data[i00] = temp; | ||
| 6612 | |||
| 6613 | break; | ||
| 6614 | } | ||
| 6615 | } | ||
| 6616 | |||
| 6617 | typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t; | ||
| 6618 | |||
| 6619 | template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>; | ||
| 6620 | template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>; | ||
| 6621 | template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>; | ||
| 6622 | template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>; | ||
| 6623 | template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>; | ||
| 6624 | |||
| 6625 | template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>; | ||
| 6626 | template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>; | ||
| 6627 | template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>; | ||
| 6628 | template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>; | ||
| 6629 | template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>; | ||
| 6630 | |||
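| | // concatenation along args.dim: o[dim] is the size of src0 along that dimension; indices inside | ||
| | // src0's extent read from src0, the rest read from src1 shifted back by o | ||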
| 6631 | kernel void kernel_concat( | ||
| 6632 | constant ggml_metal_kargs_concat & args, | ||
| 6633 | device const char * src0, | ||
| 6634 | device const char * src1, | ||
| 6635 | device char * dst, | ||
| 6636 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 6637 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 6638 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 6639 | |||
| 6640 | const int i3 = tgpig.z; | ||
| 6641 | const int i2 = tgpig.y; | ||
| 6642 | const int i1 = tgpig.x; | ||
| 6643 | |||
| 6644 | int o[4] = {0, 0, 0, 0}; | ||
| 6645 | o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03)); | ||
| 6646 | |||
| 6647 | device const float * x; | ||
| 6648 | |||
| 6649 | for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { | ||
| 6650 | if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { | ||
| 6651 | x = (device const float *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00); | ||
| 6652 | } else { | ||
| 6653 | x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10); | ||
| 6654 | } | ||
| 6655 | |||
| 6656 | device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); | ||
| 6657 | |||
| 6658 | *y = *x; | ||
| 6659 | } | ||
| 6660 | } | ||
| 6661 | |||
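| | // q2_K mat-vec: each simdgroup accumulates nr0 output rows; the 2-bit quants are loaded as packed | ||
| | // 16-bit words, so acc1 picks the low byte and acc2 the high byte (hence the extra 1/256), and the | ||
| | // 1/1, 1/4, 1/16, 1/64 factors undo the bit position of each masked-but-unshifted 2-bit field; | ||
| | // the 4-bit mins use the upper nibble of each scale (sc & 0xF0), compensated by the 1/16 on dh[1] | ||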
| 6662 | template<int nr0, typename args_t> | ||
| 6663 | void kernel_mul_mv_q2_K_f32_impl( | ||
| 6664 | args_t args, | ||
| 6665 | device const char * src0, | ||
| 6666 | device const char * src1, | ||
| 6667 | device char * dst, | ||
| 6668 | threadgroup char * shmem, | ||
| 6669 | uint3 tgpig, | ||
| 6670 | ushort tiisg, | ||
| 6671 | ushort sgitg) { | ||
| 6672 | const short NSG = FC_mul_mv_nsg; | ||
| 6673 | |||
| 6674 | const int nb = args.ne00/QK_K; | ||
| 6675 | |||
| 6676 | const int r0 = tgpig.x; | ||
| 6677 | const int r1 = tgpig.y; | ||
| 6678 | const int im = tgpig.z; | ||
| 6679 | |||
| 6680 | const int first_row = (r0 * NSG + sgitg) * nr0; | ||
| 6681 | |||
| 6682 | const uint i12 = im%args.ne12; | ||
| 6683 | const uint i13 = im/args.ne12; | ||
| 6684 | |||
| 6685 | const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 6686 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 6687 | |||
| 6688 | device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); | ||
| 6689 | device const float * y = (device const float *) (src1 + offset1); | ||
| 6690 | |||
| 6691 | float yl[32]; | ||
| 6692 | float sumf[nr0]={0.f}; | ||
| 6693 | |||
| 6694 | const short ix = tiisg/8; // 0...3 | ||
| 6695 | const short it = tiisg%8; // 0...7 | ||
| 6696 | const short iq = it/4; // 0 or 1 | ||
| 6697 | const short ir = it%4; // 0...3 | ||
| 6698 | const short is = (8*ir)/16;// 0 or 1 | ||
| 6699 | |||
| 6700 | device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; | ||
| 6701 | |||
| 6702 | for (int ib = ix; ib < nb; ib += 4) { | ||
| 6703 | float4 sumy = {0.f, 0.f, 0.f, 0.f}; | ||
| 6704 | for (short i = 0; i < 8; ++i) { | ||
| 6705 | yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; | ||
| 6706 | yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; | ||
| 6707 | yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; | ||
| 6708 | yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; | ||
| 6709 | } | ||
| 6710 | |||
| 6711 | device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is; | ||
| 6712 | device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; | ||
| 6713 | device const half * dh = &x[ib].d; | ||
| 6714 | |||
| 6715 | for (short row = 0; row < nr0; row++) { | ||
| 6716 | float4 acc1 = {0.f, 0.f, 0.f, 0.f}; | ||
| 6717 | float4 acc2 = {0.f, 0.f, 0.f, 0.f}; | ||
| 6718 | for (int i = 0; i < 8; i += 2) { | ||
| 6719 | acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); | ||
| 6720 | acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); | ||
| 6721 | acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); | ||
| 6722 | acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); | ||
| 6723 | acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); | ||
| 6724 | acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); | ||
| 6725 | acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); | ||
| 6726 | acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); | ||
| 6727 | } | ||
| 6728 | float dall = dh[0]; | ||
| 6729 | float dmin = dh[1] * 1.f/16.f; | ||
| 6730 | sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + | ||
| 6731 | (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + | ||
| 6732 | (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + | ||
| 6733 | (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - | ||
| 6734 | dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); | ||
| 6735 | |||
| 6736 | qs += args.nb01/2; | ||
| 6737 | sc += args.nb01; | ||
| 6738 | dh += args.nb01/2; | ||
| 6739 | } | ||
| 6740 | |||
| 6741 | y4 += 4 * QK_K; | ||
| 6742 | } | ||
| 6743 | |||
| 6744 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 6745 | |||
| 6746 | for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { | ||
| 6747 | float sum_all = simd_sum(sumf[row]); | ||
| 6748 | if (tiisg == 0) { | ||
| 6749 | dst_f32[first_row + row] = sum_all; | ||
| 6750 | } | ||
| 6751 | } | ||
| 6752 | } | ||
| 6753 | |||
| 6754 | [[host_name("kernel_mul_mv_q2_K_f32")]] | ||
| 6755 | kernel void kernel_mul_mv_q2_K_f32( | ||
| 6756 | constant ggml_metal_kargs_mul_mv & args, | ||
| 6757 | device const char * src0, | ||
| 6758 | device const char * src1, | ||
| 6759 | device char * dst, | ||
| 6760 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 6761 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 6762 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 6763 | |||
| 6764 | kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); | ||
| 6765 | } | ||
| 6766 | |||
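The masks in the q2_K inner loop come in low-byte/high-byte pairs (`0x0003`/`0x0300`, `0x000c`/`0x0c00`, ...) because the 2-bit quants are read through a `uint16_t` pointer: masking the high byte in place leaves that lane's value multiplied by 256, and the `1.f/256.f` factor on `acc2` folds it back without a shift. A small stand-alone C++ check of the identity, with illustrative inputs (not kernel code):

```cpp
// Sketch only: why acc1 + (1/256) * acc2 equals the explicitly unpacked dot product.
#include <cassert>
#include <cstdint>

int main() {
    const uint16_t qs   = 0xB7A2;              // two packed bytes, illustrative value
    const float    y_lo = 0.5f;                // activation paired with bits 0..1
    const float    y_hi = -1.25f;              // activation paired with bits 8..9

    // reference: extract each 2-bit field explicitly, then multiply
    const float ref = y_lo * ((qs >> 0) & 3) + y_hi * ((qs >> 8) & 3);

    // kernel-style: mask in place; the high lane is 256x too large and gets folded back
    const float acc1  = y_lo * (qs & 0x0003);
    const float acc2  = y_hi * (qs & 0x0300);
    const float fused = acc1 + (1.0f/256.0f) * acc2;

    assert(fused == ref);                      // exact for these small integer fields
    return 0;
}
```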
| 6767 | template<int nr0, typename args_t> | ||
| 6768 | void kernel_mul_mv_q3_K_f32_impl( | ||
| 6769 | args_t args, | ||
| 6770 | device const char * src0, | ||
| 6771 | device const char * src1, | ||
| 6772 | device char * dst, | ||
| 6773 | threadgroup char * shmem, | ||
| 6774 | uint3 tgpig, | ||
| 6775 | ushort tiisg, | ||
| 6776 | ushort sgitg) { | ||
| 6777 | const short NSG = FC_mul_mv_nsg; | ||
| 6778 | |||
| 6779 | const int nb = args.ne00/QK_K; | ||
| 6780 | |||
| 6781 | const int r0 = tgpig.x; | ||
| 6782 | const int r1 = tgpig.y; | ||
| 6783 | const int im = tgpig.z; | ||
| 6784 | |||
| 6785 | const int first_row = (r0 * NSG + sgitg) * nr0; | ||
| 6786 | |||
| 6787 | const uint i12 = im%args.ne12; | ||
| 6788 | const uint i13 = im/args.ne12; | ||
| 6789 | |||
| 6790 | const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 6791 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 6792 | |||
| 6793 | device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); | ||
| 6794 | device const float * yy = (device const float *) (src1 + offset1); | ||
| 6795 | |||
| 6796 | float yl[32]; | ||
| 6797 | |||
| 6798 | //const uint16_t kmask1 = 0x3030; | ||
| 6799 | //const uint16_t kmask2 = 0x0f0f; | ||
| 6800 | |||
| 6801 | const short tid = tiisg/4; | ||
| 6802 | const short ix = tiisg%4; | ||
| 6803 | const short ip = tid/4; // 0 or 1 | ||
| 6804 | const short il = 2*((tid%4)/2); // 0 or 2 | ||
| 6805 | const short ir = tid%2; | ||
| 6806 | const short l0 = 8*ir; | ||
| 6807 | |||
| 6808 | // One would think that the Metal compiler would figure out that ip and il can only have | ||
| 6809 | // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it | ||
| 6810 | // with these two tables (see the sketch after kernel_mul_mv_q3_K_f32 below). | ||
| 6811 | // | ||
| 6812 | // Possible masks for the high bit | ||
| 6813 | const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200}, // ip = 0, il = 0 | ||
| 6814 | {0x0004, 0x0400, 0x0008, 0x0800}, // ip = 0, il = 2 | ||
| 6815 | {0x0010, 0x1000, 0x0020, 0x2000}, // ip = 1, il = 0 | ||
| 6816 | {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2 | ||
| 6817 | |||
| 6818 | // Possible masks for the low 2 bits | ||
| 6819 | const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}}; | ||
| 6820 | |||
| 6821 | const ushort4 hm = mm[2*ip + il/2]; | ||
| 6822 | |||
| 6823 | const short shift = 2*il; | ||
| 6824 | |||
| 6825 | const float v1 = il == 0 ? 4.f : 64.f; | ||
| 6826 | const float v2 = 4.f * v1; | ||
| 6827 | |||
| 6828 | const uint16_t s_shift1 = 4*ip; | ||
| 6829 | const uint16_t s_shift2 = s_shift1 + il; | ||
| 6830 | |||
| 6831 | const short q_offset = 32*ip + l0; | ||
| 6832 | const short y_offset = 128*ip + 32*il + l0; | ||
| 6833 | |||
| 6834 | device const float * y1 = yy + ix*QK_K + y_offset; | ||
| 6835 | |||
| 6836 | uint32_t scales32, aux32; | ||
| 6837 | thread uint16_t * scales16 = (thread uint16_t *)&scales32; | ||
| 6838 | thread const int8_t * scales = (thread const int8_t *)&scales32; | ||
| 6839 | |||
| 6840 | float sumf1[nr0] = {0.f}; | ||
| 6841 | float sumf2[nr0] = {0.f}; | ||
| 6842 | |||
| 6843 | for (int i = ix; i < nb; i += 4) { | ||
| 6844 | for (short l = 0; l < 8; ++l) { | ||
| 6845 | yl[l+ 0] = y1[l+ 0]; | ||
| 6846 | yl[l+ 8] = y1[l+16]; | ||
| 6847 | yl[l+16] = y1[l+32]; | ||
| 6848 | yl[l+24] = y1[l+48]; | ||
| 6849 | } | ||
| 6850 | |||
| 6851 | device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); | ||
| 6852 | device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); | ||
| 6853 | device const uint16_t * a = (device const uint16_t *)(x[i].scales); | ||
| 6854 | device const half * dh = &x[i].d; | ||
| 6855 | |||
| 6856 | for (short row = 0; row < nr0; ++row) { | ||
| 6857 | const float d_all = (float)dh[0]; | ||
| 6858 | |||
| 6859 | scales16[0] = a[4]; | ||
| 6860 | scales16[1] = a[5]; | ||
| 6861 | aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030; | ||
| 6862 | scales16[0] = a[il+0]; | ||
| 6863 | scales16[1] = a[il+1]; | ||
| 6864 | scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; | ||
| 6865 | |||
| 6866 | float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; | ||
| 6867 | for (short l = 0; l < 8; l += 2) { | ||
| 6868 | const int32_t qs = q[l/2]; | ||
| 6869 | s1 += yl[l+0] * (qs & qm[il/2][0]); | ||
| 6870 | s2 += yl[l+1] * (qs & qm[il/2][1]); | ||
| 6871 | s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]); | ||
| 6872 | s4 += yl[l+16] * (qs & qm[il/2][2]); | ||
| 6873 | s5 += yl[l+17] * (qs & qm[il/2][3]); | ||
| 6874 | s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]); | ||
| 6875 | } | ||
| 6876 | float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); | ||
| 6877 | float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); | ||
| 6878 | sumf1[row] += d1 * (scales[0] - 32); | ||
| 6879 | sumf2[row] += d2 * (scales[2] - 32); | ||
| 6880 | |||
| 6881 | s1 = s2 = s3 = s4 = s5 = s6 = 0; | ||
| 6882 | for (short l = 0; l < 8; l += 2) { | ||
| 6883 | const int32_t qs = q[l/2+8]; | ||
| 6884 | s1 += yl[l+8] * (qs & qm[il/2][0]); | ||
| 6885 | s2 += yl[l+9] * (qs & qm[il/2][1]); | ||
| 6886 | s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]); | ||
| 6887 | s4 += yl[l+24] * (qs & qm[il/2][2]); | ||
| 6888 | s5 += yl[l+25] * (qs & qm[il/2][3]); | ||
| 6889 | s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]); | ||
| 6890 | } | ||
| 6891 | d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1); | ||
| 6892 | d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2); | ||
| 6893 | sumf1[row] += d1 * (scales[1] - 32); | ||
| 6894 | sumf2[row] += d2 * (scales[3] - 32); | ||
| 6895 | |||
| 6896 | q += args.nb01/2; | ||
| 6897 | h += args.nb01/2; | ||
| 6898 | a += args.nb01/2; | ||
| 6899 | dh += args.nb01/2; | ||
| 6900 | } | ||
| 6901 | |||
| 6902 | y1 += 4 * QK_K; | ||
| 6903 | } | ||
| 6904 | |||
| 6905 | for (int row = 0; row < nr0; ++row) { | ||
| 6906 | const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); | ||
| 6907 | sumf1[row] = simd_sum(sumf); | ||
| 6908 | } | ||
| 6909 | |||
| 6910 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 6911 | |||
| 6912 | if (tiisg == 0) { | ||
| 6913 | for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { | ||
| 6914 | dst_f32[first_row + row] = sumf1[row]; | ||
| 6915 | } | ||
| 6916 | } | ||
| 6917 | } | ||
| 6918 | |||
| 6919 | [[host_name("kernel_mul_mv_q3_K_f32")]] | ||
| 6920 | kernel void kernel_mul_mv_q3_K_f32( | ||
| 6921 | constant ggml_metal_kargs_mul_mv & args, | ||
| 6922 | device const char * src0, | ||
| 6923 | device const char * src1, | ||
| 6924 | device char * dst, | ||
| 6925 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 6926 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 6927 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 6928 | |||
| 6929 | kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); | ||
| 6930 | } | ||
| 6931 | |||
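About the `mm[]`/`qm[]` tables flagged in the comment inside kernel_mul_mv_q3_K_f32_impl: they hand the compiler constant masks instead of shifts whose amounts depend on `ip` and `il`. Reading the values off, the `mm` entries appear to follow `1 << (4*ip + il + {0, 8, 1, 9})`; the host-side sketch below only checks that reading and is an assumption, not something stated in the source:

```cpp
// Sketch only: verify the apparent closed form behind the mm[] high-bit mask table.
#include <cassert>
#include <cstdint>

int main() {
    const uint16_t mm[4][4] = {
        {0x0001, 0x0100, 0x0002, 0x0200},      // ip = 0, il = 0
        {0x0004, 0x0400, 0x0008, 0x0800},      // ip = 0, il = 2
        {0x0010, 0x1000, 0x0020, 0x2000},      // ip = 1, il = 0
        {0x0040, 0x4000, 0x0080, 0x8000},      // ip = 1, il = 2
    };
    const int off[4] = {0, 8, 1, 9};           // bit offsets of the four lanes in a uint16

    for (int ip = 0; ip < 2; ++ip) {
        for (int il = 0; il <= 2; il += 2) {
            for (int k = 0; k < 4; ++k) {
                assert(mm[2*ip + il/2][k] == (uint16_t)(1u << (4*ip + il + off[k])));
            }
        }
    }
    return 0;
}
```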
| 6932 | template<int nr0, typename args_t> | ||
| 6933 | void kernel_mul_mv_q4_K_f32_impl( | ||
| 6934 | args_t args, | ||
| 6935 | device const char * src0, | ||
| 6936 | device const char * src1, | ||
| 6937 | device char * dst, | ||
| 6938 | threadgroup char * shmem, | ||
| 6939 | uint3 tgpig, | ||
| 6940 | ushort tiisg, | ||
| 6941 | ushort sgitg) { | ||
| 6942 | const short NSG = FC_mul_mv_nsg; | ||
| 6943 | |||
| 6944 | constexpr uint16_t kmask1 = 0x3f3f; | ||
| 6945 | constexpr uint16_t kmask2 = 0x0f0f; | ||
| 6946 | constexpr uint16_t kmask3 = 0xc0c0; | ||
| 6947 | |||
| 6948 | const short ix = tiisg/8; // 0...3 | ||
| 6949 | const short it = tiisg%8; // 0...7 | ||
| 6950 | const short iq = it/4; // 0 or 1 | ||
| 6951 | const short ir = it%4; // 0...3 | ||
| 6952 | |||
| 6953 | const int nb = args.ne00/QK_K; | ||
| 6954 | |||
| 6955 | const int r0 = tgpig.x; | ||
| 6956 | const int r1 = tgpig.y; | ||
| 6957 | const int im = tgpig.z; | ||
| 6958 | |||
| 6959 | const int first_row = (r0 * NSG + sgitg) * nr0; | ||
| 6960 | |||
| 6961 | const uint i12 = im%args.ne12; | ||
| 6962 | const uint i13 = im/args.ne12; | ||
| 6963 | |||
| 6964 | const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 6965 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 6966 | |||
| 6967 | device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); | ||
| 6968 | device const float * y = (device const float *) (src1 + offset1); | ||
| 6969 | |||
| 6970 | float yl[16]; | ||
| 6971 | float yh[16]; | ||
| 6972 | |||
| 6973 | float sumf[nr0]={0.f}; | ||
| 6974 | |||
| 6975 | device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; | ||
| 6976 | |||
| 6977 | uint16_t sc16[4]; | ||
| 6978 | thread const uint8_t * sc8 = (thread const uint8_t *)sc16; | ||
| 6979 | |||
| 6980 | for (int ib = ix; ib < nb; ib += 4) { | ||
| 6981 | float4 sumy = {0.f, 0.f, 0.f, 0.f}; | ||
| 6982 | |||
| 6983 | for (short i = 0; i < 8; ++i) { | ||
| 6984 | yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; | ||
| 6985 | yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; | ||
| 6986 | yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; | ||
| 6987 | yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; | ||
| 6988 | } | ||
| 6989 | |||
| 6990 | device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq; | ||
| 6991 | device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; | ||
| 6992 | device const half * dh = &x[ib].d; | ||
| 6993 | |||
| 6994 | for (short row = 0; row < nr0; row++) { | ||
| 6995 | sc16[0] = sc[0] & kmask1; | ||
| 6996 | sc16[1] = sc[2] & kmask1; | ||
| 6997 | sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); | ||
| 6998 | sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); | ||
| 6999 | |||
| 7000 | device const uint16_t * q2 = q1 + 32; | ||
| 7001 | |||
| 7002 | float4 acc1 = {0.f, 0.f, 0.f, 0.f}; | ||
| 7003 | float4 acc2 = {0.f, 0.f, 0.f, 0.f}; | ||
| 7004 | |||
| 7005 | FOR_UNROLL (short i = 0; i < 4; ++i) { | ||
| 7006 | acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F); | ||
| 7007 | acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00); | ||
| 7008 | acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0); | ||
| 7009 | acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000); | ||
| 7010 | acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F); | ||
| 7011 | acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00); | ||
| 7012 | acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0); | ||
| 7013 | acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000); | ||
| 7014 | } | ||
| 7015 | |||
| 7016 | sumf[row] += dh[0] * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + | ||
| 7017 | (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + | ||
| 7018 | (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + | ||
| 7019 | (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - | ||
| 7020 | dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); | ||
| 7021 | |||
| 7022 | q1 += args.nb01/2; | ||
| 7023 | sc += args.nb01/2; | ||
| 7024 | dh += args.nb01/2; | ||
| 7025 | } | ||
| 7026 | |||
| 7027 | y4 += 4 * QK_K; | ||
| 7028 | } | ||
| 7029 | |||
| 7030 | device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; | ||
| 7031 | |||
| 7032 | for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { | ||
| 7033 | float sum_all = simd_sum(sumf[row]); | ||
| 7034 | if (tiisg == 0) { | ||
| 7035 | dst_f32[first_row + row] = sum_all; | ||
| 7036 | } | ||
| 7037 | } | ||
| 7038 | } | ||
| 7039 | |||
| 7040 | [[host_name("kernel_mul_mv_q4_K_f32")]] | ||
| 7041 | kernel void kernel_mul_mv_q4_K_f32( | ||
| 7042 | constant ggml_metal_kargs_mul_mv & args, | ||
| 7043 | device const char * src0, | ||
| 7044 | device const char * src1, | ||
| 7045 | device char * dst, | ||
| 7046 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 7047 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 7048 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 7049 | |||
| 7050 | kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); | ||
| 7051 | } | ||
| 7052 | |||
| 7053 | template<int nr0, typename args_t> | ||
| 7054 | void kernel_mul_mv_q5_K_f32_impl( | ||
| 7055 | args_t args, | ||
| 7056 | device const char * src0, | ||
| 7057 | device const char * src1, | ||
| 7058 | device char * dst, | ||
| 7059 | threadgroup char * shmem, | ||
| 7060 | uint3 tgpig, | ||
| 7061 | ushort tiisg, | ||
| 7062 | ushort sgitg) { | ||
| 7063 | const short NSG = FC_mul_mv_nsg; | ||
| 7064 | |||
| 7065 | const int nb = args.ne00/QK_K; | ||
| 7066 | |||
| 7067 | const int r0 = tgpig.x; | ||
| 7068 | const int r1 = tgpig.y; | ||
| 7069 | const int im = tgpig.z; | ||
| 7070 | |||
| 7071 | const int first_row = (r0 * NSG + sgitg) * nr0; | ||
| 7072 | |||
| 7073 | const uint i12 = im%args.ne12; | ||
| 7074 | const uint i13 = im/args.ne12; | ||
| 7075 | |||
| 7076 | const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 7077 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 7078 | |||
| 7079 | device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); | ||
| 7080 | device const float * yy = (device const float *) (src1 + offset1); | ||
| 7081 | |||
| 7082 | float sumf[nr0]={0.f}; | ||
| 7083 | |||
| 7084 | float yl[16], yh[16]; | ||
| 7085 | |||
| 7086 | constexpr uint16_t kmask1 = 0x3f3f; | ||
| 7087 | constexpr uint16_t kmask2 = 0x0f0f; | ||
| 7088 | constexpr uint16_t kmask3 = 0xc0c0; | ||
| 7089 | |||
| 7090 | const short tid = tiisg/4; | ||
| 7091 | const short ix = tiisg%4; | ||
| 7092 | const short iq = tid/4; | ||
| 7093 | const short ir = tid%4; | ||
| 7094 | |||
| 7095 | const short l0 = 8*ir; | ||
| 7096 | const short q_offset = 32*iq + l0; | ||
| 7097 | const short y_offset = 64*iq + l0; | ||
| 7098 | |||
| 7099 | const uint8_t hm1 = 1u << (2*iq); | ||
| 7100 | const uint8_t hm2 = hm1 << 1; | ||
| 7101 | const uint8_t hm3 = hm1 << 4; | ||
| 7102 | const uint8_t hm4 = hm2 << 4; | ||
| 7103 | |||
| 7104 | uint16_t sc16[4]; | ||
| 7105 | thread const uint8_t * sc8 = (thread const uint8_t *)sc16; | ||
| 7106 | |||
| 7107 | device const float * y1 = yy + ix*QK_K + y_offset; | ||
| 7108 | |||
| 7109 | for (int i = ix; i < nb; i += 4) { | ||
| 7110 | device const uint8_t * q1 = x[i].qs + q_offset; | ||
| 7111 | device const uint8_t * qh = x[i].qh + l0; | ||
| 7112 | device const half * dh = &x[i].d; | ||
| 7113 | device const uint16_t * a = (device const uint16_t *)x[i].scales + iq; | ||
| 7114 | |||
| 7115 | device const float * y2 = y1 + 128; | ||
| 7116 | float4 sumy = {0.f, 0.f, 0.f, 0.f}; | ||
| 7117 | for (short l = 0; l < 8; ++l) { | ||
| 7118 | yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; | ||
| 7119 | yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; | ||
| 7120 | yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; | ||
| 7121 | yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; | ||
| 7122 | } | ||
| 7123 | |||
| 7124 | for (short row = 0; row < nr0; ++row) { | ||
| 7125 | device const uint8_t * q2 = q1 + 64; | ||
| 7126 | |||
| 7127 | sc16[0] = a[0] & kmask1; | ||
| 7128 | sc16[1] = a[2] & kmask1; | ||
| 7129 | sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); | ||
| 7130 | sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); | ||
| 7131 | |||
| 7132 | float4 acc1 = {0.f}; | ||
| 7133 | float4 acc2 = {0.f}; | ||
| 7134 | FOR_UNROLL (short l = 0; l < 8; ++l) { | ||
| 7135 | uint8_t h = qh[l]; | ||
| 7136 | acc1[0] += yl[l+0] * (q1[l] & 0x0F); | ||
| 7137 | acc1[1] += yl[l+8] * (q1[l] & 0xF0); | ||
| 7138 | acc1[2] += yh[l+0] * (q2[l] & 0x0F); | ||
| 7139 | acc1[3] += yh[l+8] * (q2[l] & 0xF0); | ||
| 7140 | acc2[0] += h & hm1 ? yl[l+0] : 0.f; | ||
| 7141 | acc2[1] += h & hm2 ? yl[l+8] : 0.f; | ||
| 7142 | acc2[2] += h & hm3 ? yh[l+0] : 0.f; | ||
| 7143 | acc2[3] += h & hm4 ? yh[l+8] : 0.f; | ||
| 7144 | } | ||
| 7145 | |||
| 7146 | sumf[row] += dh[0] * (sc8[0] * (acc1[0] + 16.f*acc2[0]) + | ||
| 7147 | sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) + | ||
| 7148 | sc8[4] * (acc1[2] + 16.f*acc2[2]) + | ||
| 7149 | sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - | ||
| 7150 | dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); | ||
| 7151 | |||
| 7152 | q1 += args.nb01; | ||
| 7153 | qh += args.nb01; | ||
| 7154 | dh += args.nb01/2; | ||
| 7155 | a += args.nb01/2; | ||
| 7156 | } | ||
| 7157 | |||
| 7158 | y1 += 4 * QK_K; | ||
| 7159 | } | ||
| 7160 | |||
| 7161 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 7162 | |||
| 7163 | for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { | ||
| 7164 | const float tot = simd_sum(sumf[row]); | ||
| 7165 | if (tiisg == 0) { | ||
| 7166 | dst_f32[first_row + row] = tot; | ||
| 7167 | } | ||
| 7168 | } | ||
| 7169 | } | ||
| 7170 | |||
| 7171 | [[host_name("kernel_mul_mv_q5_K_f32")]] | ||
| 7172 | kernel void kernel_mul_mv_q5_K_f32( | ||
| 7173 | constant ggml_metal_kargs_mul_mv & args, | ||
| 7174 | device const char * src0, | ||
| 7175 | device const char * src1, | ||
| 7176 | device char * dst, | ||
| 7177 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 7178 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 7179 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 7180 | |||
| 7181 | kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); | ||
| 7182 | } | ||
| 7183 | |||
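In the q5_K inner loop, `acc1` collects the 4-bit nibble of each quant from `qs` and `acc2` the extra bit from `qh`, and they are recombined as `acc1 + 16.f*acc2` (with the high-nibble lanes additionally divided by 16). A minimal C++ sanity sketch, not kernel code, showing that the nibble-plus-high-bit reconstruction covers every 5-bit value:

```cpp
// Sketch only: q5_K splits a 5-bit quant into a nibble (qs) and one high bit (qh).
#include <cassert>

int main() {
    for (int q = 0; q < 32; ++q) {             // every possible 5-bit value
        const int lo = q & 0xF;                // nibble as stored in x[i].qs
        const int hb = (q >> 4) & 1;           // extra bit as stored in x[i].qh
        assert(lo + 16*hb == q);               // what acc1 + 16*acc2 reconstructs
    }
    return 0;
}
```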
| 7184 | template<int nr0, typename args_t> | ||
| 7185 | void kernel_mul_mv_q6_K_f32_impl( | ||
| 7186 | args_t args, | ||
| 7187 | device const char * src0, | ||
| 7188 | device const char * src1, | ||
| 7189 | device char * dst, | ||
| 7190 | threadgroup char * shmem, | ||
| 7191 | uint3 tgpig, | ||
| 7192 | ushort tiisg, | ||
| 7193 | ushort sgitg) { | ||
| 7194 | const short NSG = FC_mul_mv_nsg; | ||
| 7195 | |||
| 7196 | constexpr uint8_t kmask1 = 0x03; | ||
| 7197 | constexpr uint8_t kmask2 = 0x0C; | ||
| 7198 | constexpr uint8_t kmask3 = 0x30; | ||
| 7199 | constexpr uint8_t kmask4 = 0xC0; | ||
| 7200 | |||
| 7201 | const int nb = args.ne00/QK_K; | ||
| 7202 | |||
| 7203 | const int r0 = tgpig.x; | ||
| 7204 | const int r1 = tgpig.y; | ||
| 7205 | const int im = tgpig.z; | ||
| 7206 | |||
| 7207 | const int first_row = (r0 * NSG + sgitg) * nr0; | ||
| 7208 | |||
| 7209 | const uint i12 = im%args.ne12; | ||
| 7210 | const uint i13 = im/args.ne12; | ||
| 7211 | |||
| 7212 | const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 7213 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 7214 | |||
| 7215 | device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); | ||
| 7216 | device const float * yy = (device const float *) (src1 + offset1); | ||
| 7217 | |||
| 7218 | float sumf[nr0] = { 0.f }; | ||
| 7219 | |||
| 7220 | float yl[16]; | ||
| 7221 | |||
| 7222 | const short tid = tiisg/2; | ||
| 7223 | const short ix = tiisg%2; | ||
| 7224 | const short ip = tid/8; // 0 or 1 | ||
| 7225 | const short il = tid%8; | ||
| 7226 | const short l0 = 4*il; | ||
| 7227 | const short is = 8*ip + l0/16; | ||
| 7228 | |||
| 7229 | const short y_offset = 128*ip + l0; | ||
| 7230 | const short q_offset_l = 64*ip + l0; | ||
| 7231 | const short q_offset_h = 32*ip + l0; | ||
| 7232 | |||
| 7233 | for (int i = ix; i < nb; i += 2) { | ||
| 7234 | device const uint8_t * q1 = x[i].ql + q_offset_l; | ||
| 7235 | device const uint8_t * q2 = q1 + 32; | ||
| 7236 | device const uint8_t * qh = x[i].qh + q_offset_h; | ||
| 7237 | device const int8_t * sc = x[i].scales + is; | ||
| 7238 | device const half * dh = &x[i].d; | ||
| 7239 | |||
| 7240 | device const float * y = yy + i * QK_K + y_offset; | ||
| 7241 | |||
| 7242 | for (short l = 0; l < 4; ++l) { | ||
| 7243 | yl[4*l + 0] = y[l + 0]; | ||
| 7244 | yl[4*l + 1] = y[l + 32]; | ||
| 7245 | yl[4*l + 2] = y[l + 64]; | ||
| 7246 | yl[4*l + 3] = y[l + 96]; | ||
| 7247 | } | ||
| 7248 | |||
| 7249 | for (short row = 0; row < nr0; ++row) { | ||
| 7250 | float4 sums = {0.f, 0.f, 0.f, 0.f}; | ||
| 7251 | |||
| 7252 | FOR_UNROLL (short l = 0; l < 4; ++l) { | ||
| 7253 | sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); | ||
| 7254 | sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); | ||
| 7255 | sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32); | ||
| 7256 | sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); | ||
| 7257 | } | ||
| 7258 | |||
| 7259 | sumf[row] += dh[0] * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); | ||
| 7260 | |||
| 7261 | q1 += args.nb01; | ||
| 7262 | q2 += args.nb01; | ||
| 7263 | qh += args.nb01; | ||
| 7264 | sc += args.nb01; | ||
| 7265 | dh += args.nb01/2; | ||
| 7266 | } | ||
| 7267 | } | ||
| 7268 | |||
| 7269 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 7270 | |||
| 7271 | for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { | ||
| 7272 | float sum_all = simd_sum(sumf[row]); | ||
| 7273 | if (tiisg == 0) { | ||
| 7274 | dst_f32[first_row + row] = sum_all; | ||
| 7275 | } | ||
| 7276 | } | ||
| 7277 | } | ||
| 7278 | |||
| 7279 | [[host_name("kernel_mul_mv_q6_K_f32")]] | ||
| 7280 | kernel void kernel_mul_mv_q6_K_f32( | ||
| 7281 | constant ggml_metal_kargs_mul_mv & args, | ||
| 7282 | device const char * src0, | ||
| 7283 | device const char * src1, | ||
| 7284 | device char * dst, | ||
| 7285 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 7286 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 7287 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 7288 | |||
| 7289 | kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); | ||
| 7290 | } | ||
| 7291 | |||
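The `(int8_t)(... | ((qh[l] & kmask) << n)) - 32` expressions in the q6_K loop rebuild a signed 6-bit quant from a 4-bit low part in `ql` and a 2-bit high part in `qh`, then recenter it to the range -32..31. A minimal C++ sanity sketch (not kernel code) that walks all 64 values:

```cpp
// Sketch only: 6-bit reconstruction from a low nibble plus two high bits, minus 32.
#include <cassert>
#include <cstdint>

int main() {
    for (int q = 0; q < 64; ++q) {             // every possible 6-bit value
        const uint8_t ql = q & 0xF;            // low nibble, as stored in x[i].ql
        const uint8_t qh = (q >> 4) & 0x3;     // two high bits, as stored in x[i].qh
        const int8_t  v  = (int8_t)(ql | (qh << 4)) - 32;
        assert(v == q - 32);
    }
    return 0;
}
```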
| 7292 | // ======================= "True" 2-bit | ||
| 7293 | |||
| 7294 | template<int nr0, typename args_t> | ||
| 7295 | void kernel_mul_mv_iq2_xxs_f32_impl( | ||
| 7296 | args_t args, | ||
| 7297 | device const char * src0, | ||
| 7298 | device const char * src1, | ||
| 7299 | device char * dst, | ||
| 7300 | threadgroup char * shmem, | ||
| 7301 | uint3 tgpig, | ||
| 7302 | ushort tiisg, | ||
| 7303 | ushort sgitg) { | ||
| 7304 | const short NSG = FC_mul_mv_nsg; | ||
| 7305 | |||
| 7306 | const int nb = args.ne00/QK_K; | ||
| 7307 | |||
| 7308 | const int r0 = tgpig.x; | ||
| 7309 | const int r1 = tgpig.y; | ||
| 7310 | const int im = tgpig.z; | ||
| 7311 | |||
| 7312 | const int first_row = (r0 * NSG + sgitg) * nr0; | ||
| 7313 | |||
| 7314 | const uint i12 = im%args.ne12; | ||
| 7315 | const uint i13 = im/args.ne12; | ||
| 7316 | |||
| 7317 | const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 7318 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 7319 | |||
| 7320 | device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); | ||
| 7321 | device const float * y = (device const float *) (src1 + offset1); | ||
| 7322 | |||
| 7323 | float yl[32]; | ||
| 7324 | float sumf[nr0]={0.f}; | ||
| 7325 | |||
| 7326 | const int nb32 = nb * (QK_K / 32); | ||
| 7327 | |||
| 7328 | threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem); | ||
| 7329 | threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256); | ||
| 7330 | { | ||
| 7331 | int nval = 4; | ||
| 7332 | int pos = (32*sgitg + tiisg)*nval; | ||
| 7333 | for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i]; | ||
| 7334 | nval = 2; | ||
| 7335 | pos = (32*sgitg + tiisg)*nval; | ||
| 7336 | for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; | ||
| 7337 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 7338 | } | ||
| 7339 | |||
| 7340 | const int ix = tiisg; | ||
| 7341 | |||
| 7342 | device const float * y4 = y + 32 * ix; | ||
| 7343 | |||
| 7344 | for (int ib32 = ix; ib32 < nb32; ib32 += 32) { | ||
| 7345 | for (short i = 0; i < 32; ++i) { | ||
| 7346 | yl[i] = y4[i]; | ||
| 7347 | } | ||
| 7348 | |||
| 7349 | const int ibl = ib32 / (QK_K / 32); | ||
| 7350 | const int ib = ib32 % (QK_K / 32); | ||
| 7351 | |||
| 7352 | device const block_iq2_xxs * xr = x + ibl; | ||
| 7353 | device const uint16_t * q2 = xr->qs + 4 * ib; | ||
| 7354 | device const half * dh = &xr->d; | ||
| 7355 | |||
| 7356 | for (short row = 0; row < nr0; row++) { | ||
| 7357 | const float db = dh[0]; | ||
| 7358 | device const uint8_t * aux8 = (device const uint8_t *)q2; | ||
| 7359 | const uint32_t aux32 = q2[2] | (q2[3] << 16); | ||
| 7360 | const float d = db * (0.5f + (aux32 >> 28)); | ||
| 7361 | |||
| 7362 | float sum = 0; | ||
| 7363 | for (short l = 0; l < 4; ++l) { | ||
| 7364 | const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]); | ||
| 7365 | const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; | ||
| 7366 | for (short j = 0; j < 8; ++j) { | ||
| 7367 | sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); | ||
| 7368 | } | ||
| 7369 | } | ||
| 7370 | sumf[row] += d * sum; | ||
| 7371 | |||
| 7372 | dh += args.nb01/2; | ||
| 7373 | q2 += args.nb01/2; | ||
| 7374 | } | ||
| 7375 | |||
| 7376 | y4 += 32 * 32; | ||
| 7377 | } | ||
| 7378 | |||
| 7379 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 7380 | |||
| 7381 | for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { | ||
| 7382 | float sum_all = simd_sum(sumf[row]); | ||
| 7383 | if (tiisg == 0) { | ||
| 7384 | dst_f32[first_row + row] = sum_all * 0.25f; | ||
| 7385 | } | ||
| 7386 | } | ||
| 7387 | } | ||
| 7388 | |||
| 7389 | [[host_name("kernel_mul_mv_iq2_xxs_f32")]] | ||
| 7390 | kernel void kernel_mul_mv_iq2_xxs_f32( | ||
| 7391 | constant ggml_metal_kargs_mul_mv & args, | ||
| 7392 | device const char * src0, | ||
| 7393 | device const char * src1, | ||
| 7394 | device char * dst, | ||
| 7395 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 7396 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 7397 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 7398 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 7399 | kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||
| 7400 | } | ||
| 7401 | |||
| 7402 | template<int nr0, typename args_t> | ||
| 7403 | void kernel_mul_mv_iq2_xs_f32_impl( | ||
| 7404 | args_t args, | ||
| 7405 | device const char * src0, | ||
| 7406 | device const char * src1, | ||
| 7407 | device char * dst, | ||
| 7408 | threadgroup char * shmem, | ||
| 7409 | uint3 tgpig, | ||
| 7410 | ushort tiisg, | ||
| 7411 | ushort sgitg) { | ||
| 7412 | const short NSG = FC_mul_mv_nsg; | ||
| 7413 | |||
| 7414 | const int nb = args.ne00/QK_K; | ||
| 7415 | |||
| 7416 | const int r0 = tgpig.x; | ||
| 7417 | const int r1 = tgpig.y; | ||
| 7418 | const int im = tgpig.z; | ||
| 7419 | |||
| 7420 | const int first_row = (r0 * NSG + sgitg) * nr0; | ||
| 7421 | |||
| 7422 | const uint i12 = im%args.ne12; | ||
| 7423 | const uint i13 = im/args.ne12; | ||
| 7424 | |||
| 7425 | const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 7426 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 7427 | |||
| 7428 | device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); | ||
| 7429 | device const float * y = (device const float *) (src1 + offset1); | ||
| 7430 | |||
| 7431 | float yl[32]; | ||
| 7432 | float sumf[nr0]={0.f}; | ||
| 7433 | |||
| 7434 | const int nb32 = nb * (QK_K / 32); | ||
| 7435 | |||
| 7436 | threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem); | ||
| 7437 | threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 512); | ||
| 7438 | { | ||
| 7439 | int nval = 8; | ||
| 7440 | int pos = (32*sgitg + tiisg)*nval; | ||
| 7441 | for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i]; | ||
| 7442 | nval = 2; | ||
| 7443 | pos = (32*sgitg + tiisg)*nval; | ||
| 7444 | for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; | ||
| 7445 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 7446 | } | ||
| 7447 | |||
| 7448 | const int ix = tiisg; | ||
| 7449 | |||
| 7450 | device const float * y4 = y + 32 * ix; | ||
| 7451 | |||
| 7452 | for (int ib32 = ix; ib32 < nb32; ib32 += 32) { | ||
| 7453 | for (short i = 0; i < 32; ++i) { | ||
| 7454 | yl[i] = y4[i]; | ||
| 7455 | } | ||
| 7456 | |||
| 7457 | const int ibl = ib32 / (QK_K / 32); | ||
| 7458 | const int ib = ib32 % (QK_K / 32); | ||
| 7459 | |||
| 7460 | device const block_iq2_xs * xr = x + ibl; | ||
| 7461 | device const uint16_t * q2 = xr->qs + 4 * ib; | ||
| 7462 | device const uint8_t * sc = xr->scales + ib; | ||
| 7463 | device const half * dh = &xr->d; | ||
| 7464 | |||
| 7465 | for (short row = 0; row < nr0; row++) { | ||
| 7466 | const float db = dh[0]; | ||
| 7467 | const uint8_t ls1 = sc[0] & 0xf; | ||
| 7468 | const uint8_t ls2 = sc[0] >> 4; | ||
| 7469 | const float d1 = db * (0.5f + ls1); | ||
| 7470 | const float d2 = db * (0.5f + ls2); | ||
| 7471 | |||
| 7472 | float sum1 = 0, sum2 = 0; | ||
| 7473 | for (short l = 0; l < 2; ++l) { | ||
| 7474 | const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); | ||
| 7475 | const uint8_t signs = ssigns[(q2[l] >> 9)]; | ||
| 7476 | for (short j = 0; j < 8; ++j) { | ||
| 7477 | sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); | ||
| 7478 | } | ||
| 7479 | } | ||
| 7480 | for (short l = 2; l < 4; ++l) { | ||
| 7481 | const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); | ||
| 7482 | const uint8_t signs = ssigns[(q2[l] >> 9)]; | ||
| 7483 | for (short j = 0; j < 8; ++j) { | ||
| 7484 | sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); | ||
| 7485 | } | ||
| 7486 | } | ||
| 7487 | sumf[row] += d1 * sum1 + d2 * sum2; | ||
| 7488 | |||
| 7489 | dh += args.nb01/2; | ||
| 7490 | q2 += args.nb01/2; | ||
| 7491 | sc += args.nb01; | ||
| 7492 | } | ||
| 7493 | |||
| 7494 | y4 += 32 * 32; | ||
| 7495 | } | ||
| 7496 | |||
| 7497 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 7498 | |||
| 7499 | for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { | ||
| 7500 | float sum_all = simd_sum(sumf[row]); | ||
| 7501 | if (tiisg == 0) { | ||
| 7502 | dst_f32[first_row + row] = sum_all * 0.25f; | ||
| 7503 | } | ||
| 7504 | } | ||
| 7505 | } | ||
| 7506 | |||
| 7507 | [[host_name("kernel_mul_mv_iq2_xs_f32")]] | ||
| 7508 | kernel void kernel_mul_mv_iq2_xs_f32( | ||
| 7509 | constant ggml_metal_kargs_mul_mv & args, | ||
| 7510 | device const char * src0, | ||
| 7511 | device const char * src1, | ||
| 7512 | device char * dst, | ||
| 7513 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 7514 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 7515 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 7516 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 7517 | |||
| 7518 | kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||
| 7519 | } | ||
| 7520 | |||
| 7521 | template<int nr0, typename args_t> | ||
| 7522 | void kernel_mul_mv_iq3_xxs_f32_impl( | ||
| 7523 | args_t args, | ||
| 7524 | device const char * src0, | ||
| 7525 | device const char * src1, | ||
| 7526 | device char * dst, | ||
| 7527 | threadgroup char * shmem, | ||
| 7528 | uint3 tgpig, | ||
| 7529 | ushort tiisg, | ||
| 7530 | ushort sgitg) { | ||
| 7531 | const short NSG = FC_mul_mv_nsg; | ||
| 7532 | |||
| 7533 | const int nb = args.ne00/QK_K; | ||
| 7534 | |||
| 7535 | const int r0 = tgpig.x; | ||
| 7536 | const int r1 = tgpig.y; | ||
| 7537 | const int im = tgpig.z; | ||
| 7538 | |||
| 7539 | const int first_row = (r0 * NSG + sgitg) * nr0; | ||
| 7540 | |||
| 7541 | const uint i12 = im%args.ne12; | ||
| 7542 | const uint i13 = im/args.ne12; | ||
| 7543 | |||
| 7544 | const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 7545 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 7546 | |||
| 7547 | device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); | ||
| 7548 | device const float * y = (device const float *) (src1 + offset1); | ||
| 7549 | |||
| 7550 | float yl[32]; | ||
| 7551 | float sumf[nr0]={0.f}; | ||
| 7552 | |||
| 7553 | const int nb32 = nb * (QK_K / 32); | ||
| 7554 | |||
| 7555 | threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem); | ||
| 7556 | threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256); | ||
| 7557 | { | ||
| 7558 | int nval = 4; | ||
| 7559 | int pos = (32*sgitg + tiisg)*nval; | ||
| 7560 | for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i]; | ||
| 7561 | nval = 2; | ||
| 7562 | pos = (32*sgitg + tiisg)*nval; | ||
| 7563 | for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; | ||
| 7564 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 7565 | } | ||
| 7566 | |||
| 7567 | const int ix = tiisg; | ||
| 7568 | |||
| 7569 | device const float * y4 = y + 32 * ix; | ||
| 7570 | |||
| 7571 | for (int ib32 = ix; ib32 < nb32; ib32 += 32) { | ||
| 7572 | for (short i = 0; i < 32; ++i) { | ||
| 7573 | yl[i] = y4[i]; | ||
| 7574 | } | ||
| 7575 | |||
| 7576 | const int ibl = ib32 / (QK_K / 32); | ||
| 7577 | const int ib = ib32 % (QK_K / 32); | ||
| 7578 | |||
| 7579 | device const block_iq3_xxs * xr = x + ibl; | ||
| 7580 | device const uint8_t * q3 = xr->qs + 8 * ib; | ||
| 7581 | device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; | ||
| 7582 | device const half * dh = &xr->d; | ||
| 7583 | |||
| 7584 | for (short row = 0; row < nr0; row++) { | ||
| 7585 | const float db = dh[0]; | ||
| 7586 | const uint32_t aux32 = gas[0] | (gas[1] << 16); | ||
| 7587 | const float d = db * (0.5f + (aux32 >> 28)); | ||
| 7588 | |||
| 7589 | float2 sum = {0}; | ||
| 7590 | for (short l = 0; l < 4; ++l) { | ||
| 7591 | const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]); | ||
| 7592 | const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]); | ||
| 7593 | const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; | ||
| 7594 | for (short j = 0; j < 4; ++j) { | ||
| 7595 | sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); | ||
| 7596 | sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); | ||
| 7597 | } | ||
| 7598 | } | ||
| 7599 | sumf[row] += d * (sum[0] + sum[1]); | ||
| 7600 | |||
| 7601 | dh += args.nb01/2; | ||
| 7602 | q3 += args.nb01; | ||
| 7603 | gas += args.nb01/2; | ||
| 7604 | } | ||
| 7605 | |||
| 7606 | y4 += 32 * 32; | ||
| 7607 | } | ||
| 7608 | |||
| 7609 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 7610 | |||
| 7611 | for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { | ||
| 7612 | float sum_all = simd_sum(sumf[row]); | ||
| 7613 | if (tiisg == 0) { | ||
| 7614 | dst_f32[first_row + row] = sum_all * 0.5f; | ||
| 7615 | } | ||
| 7616 | } | ||
| 7617 | } | ||
| 7618 | |||
| 7619 | [[host_name("kernel_mul_mv_iq3_xxs_f32")]] | ||
| 7620 | kernel void kernel_mul_mv_iq3_xxs_f32( | ||
| 7621 | constant ggml_metal_kargs_mul_mv & args, | ||
| 7622 | device const char * src0, | ||
| 7623 | device const char * src1, | ||
| 7624 | device char * dst, | ||
| 7625 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 7626 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 7627 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 7628 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 7629 | |||
| 7630 | kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||
| 7631 | } | ||
| 7632 | |||
| 7633 | template<int nr0, typename args_t> | ||
| 7634 | void kernel_mul_mv_iq3_s_f32_impl( | ||
| 7635 | args_t args, | ||
| 7636 | device const char * src0, | ||
| 7637 | device const char * src1, | ||
| 7638 | device char * dst, | ||
| 7639 | threadgroup char * shmem, | ||
| 7640 | uint3 tgpig, | ||
| 7641 | ushort tiisg, | ||
| 7642 | ushort sgitg) { | ||
| 7643 | const short NSG = FC_mul_mv_nsg; | ||
| 7644 | |||
| 7645 | const int nb = args.ne00/QK_K; | ||
| 7646 | |||
| 7647 | const int r0 = tgpig.x; | ||
| 7648 | const int r1 = tgpig.y; | ||
| 7649 | const int im = tgpig.z; | ||
| 7650 | |||
| 7651 | const int first_row = (r0 * NSG + sgitg) * nr0; | ||
| 7652 | |||
| 7653 | const uint i12 = im%args.ne12; | ||
| 7654 | const uint i13 = im/args.ne12; | ||
| 7655 | |||
| 7656 | const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 7657 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 7658 | |||
| 7659 | device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); | ||
| 7660 | device const float * y = (device const float *) (src1 + offset1); | ||
| 7661 | |||
| 7662 | float yl[32]; | ||
| 7663 | float sumf[nr0]={0.f}; | ||
| 7664 | |||
| 7665 | const int nb32 = nb * (QK_K / 32); | ||
| 7666 | |||
| 7667 | threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem; | ||
| 7668 | { | ||
| 7669 | int nval = 8; | ||
| 7670 | int pos = (32*sgitg + tiisg)*nval; | ||
| 7671 | for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i]; | ||
| 7672 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 7673 | } | ||
| 7674 | |||
| 7675 | const int ix = tiisg; | ||
| 7676 | |||
| 7677 | device const float * y4 = y + 32 * ix; | ||
| 7678 | |||
| 7679 | for (int ib32 = ix; ib32 < nb32; ib32 += 32) { | ||
| 7680 | for (short i = 0; i < 32; ++i) { | ||
| 7681 | yl[i] = y4[i]; | ||
| 7682 | } | ||
| 7683 | |||
| 7684 | const int ibl = ib32 / (QK_K / 32); | ||
| 7685 | const int ib = ib32 % (QK_K / 32); | ||
| 7686 | |||
| 7687 | device const block_iq3_s * xr = x + ibl; | ||
| 7688 | device const uint8_t * qs = xr->qs + 8 * ib; | ||
| 7689 | device const uint8_t * qh = xr->qh + ib; | ||
| 7690 | device const uint8_t * sc = xr->scales + (ib/2); | ||
| 7691 | device const uint8_t * signs = xr->signs + 4 * ib; | ||
| 7692 | device const half * dh = &xr->d; | ||
| 7693 | |||
| 7694 | for (short row = 0; row < nr0; row++) { | ||
| 7695 | const float db = dh[0]; | ||
| 7696 | const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); | ||
| 7697 | |||
| 7698 | float2 sum = {0}; | ||
| 7699 | for (short l = 0; l < 4; ++l) { | ||
| 7700 | const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues; | ||
| 7701 | const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues; | ||
| 7702 | const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); | ||
| 7703 | const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); | ||
| 7704 | for (short j = 0; j < 4; ++j) { | ||
| 7705 | sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); | ||
| 7706 | sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); | ||
| 7707 | } | ||
| 7708 | } | ||
| 7709 | sumf[row] += d * (sum[0] + sum[1]); | ||
| 7710 | |||
| 7711 | dh += args.nb01/2; | ||
| 7712 | qs += args.nb01; | ||
| 7713 | qh += args.nb01; | ||
| 7714 | sc += args.nb01; | ||
| 7715 | signs += args.nb01; | ||
| 7716 | } | ||
| 7717 | |||
| 7718 | y4 += 32 * 32; | ||
| 7719 | } | ||
| 7720 | |||
| 7721 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 7722 | |||
| 7723 | for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { | ||
| 7724 | float sum_all = simd_sum(sumf[row]); | ||
| 7725 | if (tiisg == 0) { | ||
| 7726 | dst_f32[first_row + row] = sum_all; | ||
| 7727 | } | ||
| 7728 | } | ||
| 7729 | } | ||
| 7730 | |||
| 7731 | [[host_name("kernel_mul_mv_iq3_s_f32")]] | ||
| 7732 | kernel void kernel_mul_mv_iq3_s_f32( | ||
| 7733 | constant ggml_metal_kargs_mul_mv & args, | ||
| 7734 | device const char * src0, | ||
| 7735 | device const char * src1, | ||
| 7736 | device char * dst, | ||
| 7737 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 7738 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 7739 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 7740 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 7741 | |||
| 7742 | kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||
| 7743 | } | ||
| 7744 | |||
| 7745 | template<int nr0, typename args_t> | ||
| 7746 | void kernel_mul_mv_iq2_s_f32_impl( | ||
| 7747 | args_t args, | ||
| 7748 | device const char * src0, | ||
| 7749 | device const char * src1, | ||
| 7750 | device char * dst, | ||
| 7751 | threadgroup char * shmem, | ||
| 7752 | uint3 tgpig, | ||
| 7753 | ushort tiisg, | ||
| 7754 | ushort sgitg) { | ||
| 7755 | const short NSG = FC_mul_mv_nsg; | ||
| 7756 | |||
| 7757 | const int nb = args.ne00/QK_K; | ||
| 7758 | |||
| 7759 | const int r0 = tgpig.x; | ||
| 7760 | const int r1 = tgpig.y; | ||
| 7761 | const int im = tgpig.z; | ||
| 7762 | |||
| 7763 | const int first_row = (r0 * NSG + sgitg) * nr0; | ||
| 7764 | |||
| 7765 | const uint i12 = im%args.ne12; | ||
| 7766 | const uint i13 = im/args.ne12; | ||
| 7767 | |||
| 7768 | const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 7769 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 7770 | |||
| 7771 | device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); | ||
| 7772 | device const float * y = (device const float *) (src1 + offset1); | ||
| 7773 | |||
| 7774 | float yl[32]; | ||
| 7775 | float sumf[nr0]={0.f}; | ||
| 7776 | |||
| 7777 | const int nb32 = nb * (QK_K / 32); | ||
| 7778 | |||
| 7779 | //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem; | ||
| 7780 | //{ | ||
| 7781 | // int nval = 32; | ||
| 7782 | // int pos = (32*sgitg + tiisg)*nval; | ||
| 7783 | // for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i]; | ||
| 7784 | // threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 7785 | //} | ||
| 7786 | |||
| 7787 | const short ix = tiisg; | ||
| 7788 | |||
| 7789 | device const float * y4 = y + 32 * ix; | ||
| 7790 | |||
| 7791 | for (int ib32 = ix; ib32 < nb32; ib32 += 32) { | ||
| 7792 | for (short i = 0; i < 32; ++i) { | ||
| 7793 | yl[i] = y4[i]; | ||
| 7794 | } | ||
| 7795 | |||
| 7796 | const int ibl = ib32 / (QK_K / 32); | ||
| 7797 | const int ib = ib32 % (QK_K / 32); | ||
| 7798 | |||
| 7799 | device const block_iq2_s * xr = x + ibl; | ||
| 7800 | device const uint8_t * qs = xr->qs + 4 * ib; | ||
| 7801 | device const uint8_t * qh = xr->qh + ib; | ||
| 7802 | device const uint8_t * sc = xr->scales + ib; | ||
| 7803 | device const uint8_t * signs = qs + QK_K/8; | ||
| 7804 | device const half * dh = &xr->d; | ||
| 7805 | |||
| 7806 | for (short row = 0; row < nr0; row++) { | ||
| 7807 | const float db = dh[0]; | ||
| 7808 | const float d1 = db * (0.5f + (sc[0] & 0xf)); | ||
| 7809 | const float d2 = db * (0.5f + (sc[0] >> 4)); | ||
| 7810 | |||
| 7811 | float2 sum = {0}; | ||
| 7812 | for (short l = 0; l < 2; ++l) { | ||
| 7813 | //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); | ||
| 7814 | //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); | ||
| 7815 | constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); | ||
| 7816 | constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); | ||
| 7817 | for (short j = 0; j < 8; ++j) { | ||
| 7818 | sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); | ||
| 7819 | sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); | ||
| 7820 | } | ||
| 7821 | } | ||
| 7822 | sumf[row] += d1 * sum[0] + d2 * sum[1]; | ||
| 7823 | |||
| 7824 | dh += args.nb01/2; | ||
| 7825 | qs += args.nb01; | ||
| 7826 | qh += args.nb01; | ||
| 7827 | sc += args.nb01; | ||
| 7828 | signs += args.nb01; | ||
| 7829 | } | ||
| 7830 | |||
| 7831 | y4 += 32 * 32; | ||
| 7832 | } | ||
| 7833 | |||
| 7834 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 7835 | |||
| 7836 | for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { | ||
| 7837 | float sum_all = simd_sum(sumf[row]); | ||
| 7838 | if (tiisg == 0) { | ||
| 7839 | dst_f32[first_row + row] = sum_all * 0.25f; | ||
| 7840 | } | ||
| 7841 | } | ||
| 7842 | } | ||
| 7843 | |||
| 7844 | [[host_name("kernel_mul_mv_iq2_s_f32")]] | ||
| 7845 | kernel void kernel_mul_mv_iq2_s_f32( | ||
| 7846 | constant ggml_metal_kargs_mul_mv & args, | ||
| 7847 | device const char * src0, | ||
| 7848 | device const char * src1, | ||
| 7849 | device char * dst, | ||
| 7850 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 7851 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 7852 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 7853 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 7854 | |||
| 7855 | kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||
| 7856 | } | ||
| 7857 | |||
| 7858 | template<int nr0, typename args_t> | ||
| 7859 | void kernel_mul_mv_iq1_s_f32_impl( | ||
| 7860 | args_t args, | ||
| 7861 | device const char * src0, | ||
| 7862 | device const char * src1, | ||
| 7863 | device char * dst, | ||
| 7864 | threadgroup char * shmem, | ||
| 7865 | uint3 tgpig, | ||
| 7866 | ushort tiisg, | ||
| 7867 | ushort sgitg) { | ||
| 7868 | const short NSG = FC_mul_mv_nsg; | ||
| 7869 | |||
| 7870 | const int nb = args.ne00/QK_K; | ||
| 7871 | |||
| 7872 | const int r0 = tgpig.x; | ||
| 7873 | const int r1 = tgpig.y; | ||
| 7874 | const int im = tgpig.z; | ||
| 7875 | |||
| 7876 | const int first_row = (r0 * NSG + sgitg) * nr0; | ||
| 7877 | |||
| 7878 | const uint i12 = im%args.ne12; | ||
| 7879 | const uint i13 = im/args.ne12; | ||
| 7880 | |||
| 7881 | const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 7882 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 7883 | |||
| 7884 | device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); | ||
| 7885 | device const float * y = (device const float *) (src1 + offset1); | ||
| 7886 | |||
| 7887 | float yl[32]; | ||
| 7888 | float sumf[nr0]={0.f}; | ||
| 7889 | |||
| 7890 | const int nb32 = nb * (QK_K / 32); | ||
| 7891 | |||
| 7892 | const short ix = tiisg; | ||
| 7893 | |||
| 7894 | device const float * y4 = y + 32 * ix; | ||
| 7895 | |||
| 7896 | for (int ib32 = ix; ib32 < nb32; ib32 += 32) { | ||
| 7897 | float sumy = 0; | ||
| 7898 | for (short i = 0; i < 32; ++i) { | ||
| 7899 | yl[i] = y4[i]; | ||
| 7900 | sumy += yl[i]; | ||
| 7901 | } | ||
| 7902 | |||
| 7903 | const int ibl = ib32 / (QK_K / 32); | ||
| 7904 | const int ib = ib32 % (QK_K / 32); | ||
| 7905 | |||
| 7906 | device const block_iq1_s * xr = x + ibl; | ||
| 7907 | device const uint8_t * qs = xr->qs + 4 * ib; | ||
| 7908 | device const uint16_t * qh = xr->qh + ib; | ||
| 7909 | device const half * dh = &xr->d; | ||
| 7910 | |||
| 7911 | for (short row = 0; row < nr0; row++) { | ||
| 7912 | constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); | ||
| 7913 | constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); | ||
| 7914 | constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); | ||
| 7915 | constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); | ||
| 7916 | |||
| 7917 | float sum = 0; | ||
| 7918 | for (short j = 0; j < 4; ++j) { | ||
| 7919 | sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) | ||
| 7920 | + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) | ||
| 7921 | + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) | ||
| 7922 | + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); | ||
| 7923 | } | ||
| 7924 | sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); | ||
| 7925 | |||
| 7926 | dh += args.nb01/2; | ||
| 7927 | qs += args.nb01; | ||
| 7928 | qh += args.nb01/2; | ||
| 7929 | } | ||
| 7930 | |||
| 7931 | y4 += 32 * 32; | ||
| 7932 | } | ||
| 7933 | |||
| 7934 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 7935 | |||
| 7936 | for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { | ||
| 7937 | float sum_all = simd_sum(sumf[row]); | ||
| 7938 | if (tiisg == 0) { | ||
| 7939 | dst_f32[first_row + row] = sum_all; | ||
| 7940 | } | ||
| 7941 | } | ||
| 7942 | } | ||
| 7943 | |||
| 7944 | [[host_name("kernel_mul_mv_iq1_s_f32")]] | ||
| 7945 | kernel void kernel_mul_mv_iq1_s_f32( | ||
| 7946 | constant ggml_metal_kargs_mul_mv & args, | ||
| 7947 | device const char * src0, | ||
| 7948 | device const char * src1, | ||
| 7949 | device char * dst, | ||
| 7950 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 7951 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 7952 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 7953 | |||
| 7954 | kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); | ||
| 7955 | } | ||
| 7956 | |||
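In the iq1_s loop, `sumy` is accumulated alongside `yl` so that the per-block additive shift (-1 ± IQ1S_DELTA, selected by the top bit of `qh[0]`) can be applied once through the distributive law, sum(y_i*(g_i + s)) = sum(y_i*g_i) + s*sum(y_i), instead of being added to all 32 grid values. A small C++ sketch of that folding with illustrative numbers (not the kernel's actual grid or delta):

```cpp
// Sketch only: folding a per-block additive shift through the sum of activations.
#include <cassert>
#include <cmath>

int main() {
    const float y[4] = {0.25f, -1.0f, 2.0f, 0.5f};   // activations (illustrative)
    const float g[4] = {1.0f, 3.0f, 0.0f, 2.0f};     // grid values (illustrative)
    const float s    = -1.0f + 0.125f;               // per-block shift, e.g. -1 + delta

    float direct = 0.f, folded = 0.f, sumy = 0.f;
    for (int i = 0; i < 4; ++i) {
        direct += y[i] * (g[i] + s);                 // shift applied to every value
        folded += y[i] * g[i];                       // plain dot product
        sumy   += y[i];
    }
    folded += s * sumy;                              // shift applied once via sumy

    assert(std::fabs(direct - folded) < 1e-6f);
    return 0;
}
```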
| 7957 | template<int nr0, typename args_t> | ||
| 7958 | void kernel_mul_mv_iq1_m_f32_impl( | ||
| 7959 | args_t args, | ||
| 7960 | device const char * src0, | ||
| 7961 | device const char * src1, | ||
| 7962 | device char * dst, | ||
| 7963 | threadgroup char * shmem, | ||
| 7964 | uint3 tgpig, | ||
| 7965 | ushort tiisg, | ||
| 7966 | ushort sgitg) { | ||
| 7967 | const short NSG = FC_mul_mv_nsg; | ||
| 7968 | |||
| 7969 | const int nb = args.ne00/QK_K; | ||
| 7970 | |||
| 7971 | const int r0 = tgpig.x; | ||
| 7972 | const int r1 = tgpig.y; | ||
| 7973 | const int im = tgpig.z; | ||
| 7974 | |||
| 7975 | const int first_row = (r0 * NSG + sgitg) * nr0; | ||
| 7976 | |||
| 7977 | const uint i12 = im%args.ne12; | ||
| 7978 | const uint i13 = im/args.ne12; | ||
| 7979 | |||
| 7980 | const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 7981 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 7982 | |||
| 7983 | device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); | ||
| 7984 | device const float * y = (device const float *) (src1 + offset1); | ||
| 7985 | |||
| 7986 | float yl[32]; | ||
| 7987 | float sumf[nr0]={0.f}; | ||
| 7988 | |||
| 7989 | const int nb32 = nb * (QK_K / 32); | ||
| 7990 | |||
| 7991 | const short ix = tiisg; | ||
| 7992 | |||
| 7993 | device const float * y4 = y + 32 * ix; | ||
| 7994 | |||
| 7995 | iq1m_scale_t scale; | ||
| 7996 | |||
| 7997 | for (int ib32 = ix; ib32 < nb32; ib32 += 32) { | ||
| 7998 | float4 sumy = {0.f}; | ||
| 7999 | for (short i = 0; i < 8; ++i) { | ||
| 8000 | yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; | ||
| 8001 | yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; | ||
| 8002 | yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; | ||
| 8003 | yl[i+24] = y4[i+24]; sumy[3] += yl[i+24]; | ||
| 8004 | } | ||
| 8005 | |||
| 8006 | const int ibl = ib32 / (QK_K / 32); | ||
| 8007 | const int ib = ib32 % (QK_K / 32); | ||
| 8008 | |||
| 8009 | device const block_iq1_m * xr = x + ibl; | ||
| 8010 | device const uint8_t * qs = xr->qs + 4 * ib; | ||
| 8011 | device const uint8_t * qh = xr->qh + 2 * ib; | ||
| 8012 | device const uint16_t * sc = (device const uint16_t *)xr->scales; | ||
| 8013 | |||
| 8014 | for (short row = 0; row < nr0; row++) { | ||
| 8015 | scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); | ||
| 8016 | |||
| 8017 | constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); | ||
| 8018 | constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700))); | ||
| 8019 | constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700))); | ||
| 8020 | constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); | ||
| 8021 | |||
| 8022 | float2 sum = {0.f}; | ||
| 8023 | for (short j = 0; j < 4; ++j) { | ||
| 8024 | sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) | ||
| 8025 | + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); | ||
| 8026 | sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) | ||
| 8027 | + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); | ||
| 8028 | } | ||
| 8029 | const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); | ||
| 8030 | const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA); | ||
| 8031 | |||
| 8032 | sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + | ||
| 8033 | (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); | ||
| 8034 | |||
| 8035 | sc += args.nb01/2; | ||
| 8036 | qs += args.nb01; | ||
| 8037 | qh += args.nb01; | ||
| 8038 | } | ||
| 8039 | |||
| 8040 | y4 += 32 * 32; | ||
| 8041 | } | ||
| 8042 | |||
| 8043 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 8044 | |||
| 8045 | for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { | ||
| 8046 | float sum_all = simd_sum(sumf[row]); | ||
| 8047 | if (tiisg == 0) { | ||
| 8048 | dst_f32[first_row + row] = sum_all; | ||
| 8049 | } | ||
| 8050 | } | ||
| 8051 | } | ||
| 8052 | |||
| 8053 | [[host_name("kernel_mul_mv_iq1_m_f32")]] | ||
| 8054 | kernel void kernel_mul_mv_iq1_m_f32( | ||
| 8055 | constant ggml_metal_kargs_mul_mv & args, | ||
| 8056 | device const char * src0, | ||
| 8057 | device const char * src1, | ||
| 8058 | device char * dst, | ||
| 8059 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 8060 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 8061 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 8062 | |||
| 8063 | kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); | ||
| 8064 | } | ||
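The fp16 block scale of block_iq1_m is not stored as a separate field: kernel_mul_mv_iq1_m_f32_impl above reassembles it from the top four bits of each of the four 16-bit words of xr->scales. A host-side C++ sketch of that unpacking (the helper name is illustrative, not from the source):

#include <cstdint>

// Pack the top nibble of each 16-bit scale word into one 16-bit value,
// which the kernel then reinterprets as a half-precision block scale.
static inline uint16_t unpack_iq1m_block_scale(const uint16_t sc[4]) {
    return ( sc[0] >> 12)            // nibble 0 -> bits  0..3
         | ((sc[1] >>  8) & 0x00f0)  // nibble 1 -> bits  4..7
         | ((sc[2] >>  4) & 0x0f00)  // nibble 2 -> bits  8..11
         | ( sc[3]        & 0xf000); // nibble 3 -> bits 12..15
}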
| 8065 | |||
| 8066 | template<int NR0, typename args_t> | ||
| 8067 | void kernel_mul_mv_iq4_nl_f32_impl( | ||
| 8068 | args_t args, | ||
| 8069 | device const char * src0, | ||
| 8070 | device const char * src1, | ||
| 8071 | device char * dst, | ||
| 8072 | threadgroup char * shmem, | ||
| 8073 | uint3 tgpig, | ||
| 8074 | ushort tiisg, | ||
| 8075 | ushort sgitg) { | ||
| 8076 | const short NSG = FC_mul_mv_nsg; | ||
| 8077 | |||
| 8078 | threadgroup float * shmem_f32 = (threadgroup float *) shmem; | ||
| 8079 | |||
| 8080 | const int r0 = tgpig.x; | ||
| 8081 | const int r1 = tgpig.y; | ||
| 8082 | const int im = tgpig.z; | ||
| 8083 | |||
| 8084 | const int first_row = (r0 * NSG + sgitg) * NR0; | ||
| 8085 | |||
| 8086 | const uint i12 = im%args.ne12; | ||
| 8087 | const uint i13 = im/args.ne12; | ||
| 8088 | |||
| 8089 | const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 8090 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 8091 | |||
| 8092 | device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); | ||
| 8093 | device const float * y = (device const float *) (src1 + offset1); | ||
| 8094 | |||
| 8095 | const int nb = args.ne00/QK4_NL; | ||
| 8096 | const int ns01 = args.nb01/args.nb00; | ||
| 8097 | |||
| 8098 | const short ix = tiisg/2; // 0...15 | ||
| 8099 | const short it = tiisg%2; // 0 or 1 | ||
| 8100 | |||
| 8101 | shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; | ||
| 8102 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 8103 | |||
| 8104 | float4 yl[4]; | ||
| 8105 | float sumf[NR0]={0.f}; | ||
| 8106 | |||
| 8107 | device const float * yb = y + ix*QK4_NL + it*8; | ||
| 8108 | |||
| 8109 | uint32_t aux32[2]; | ||
| 8110 | thread const uint8_t * q8 = (thread const uint8_t *)aux32; | ||
| 8111 | |||
| 8112 | float4 qf1, qf2; | ||
| 8113 | |||
| 8114 | // [TAG_MUL_MV_WEIRD] - see the note on the analogous loop condition in kernel_mul_mv_mxfp4_f32_impl below | ||
| 8115 | for (int ib = ix; ib < nb && ib < ns01; ib += 16) { | ||
| 8116 | device const float4 * y4 = (device const float4 *)yb; | ||
| 8117 | yl[0] = y4[0]; | ||
| 8118 | yl[1] = y4[4]; | ||
| 8119 | yl[2] = y4[1]; | ||
| 8120 | yl[3] = y4[5]; | ||
| 8121 | |||
| 8122 | for (short row = 0; row < NR0; row++) { | ||
| 8123 | device const block_iq4_nl & xb = x[row*ns01 + ib]; | ||
| 8124 | device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); | ||
| 8125 | |||
| 8126 | float4 acc1 = {0.f}, acc2 = {0.f}; | ||
| 8127 | |||
| 8128 | aux32[0] = q4[0] | (q4[1] << 16); | ||
| 8129 | aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; | ||
| 8130 | aux32[0] &= 0x0f0f0f0f; | ||
| 8131 | qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; | ||
| 8132 | qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; | ||
| 8133 | acc1 += yl[0] * qf1; | ||
| 8134 | acc2 += yl[1] * qf2; | ||
| 8135 | |||
| 8136 | aux32[0] = q4[2] | (q4[3] << 16); | ||
| 8137 | aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; | ||
| 8138 | aux32[0] &= 0x0f0f0f0f; | ||
| 8139 | qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; | ||
| 8140 | qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; | ||
| 8141 | acc1 += yl[2] * qf1; | ||
| 8142 | acc2 += yl[3] * qf2; | ||
| 8143 | |||
| 8144 | acc1 += acc2; | ||
| 8145 | |||
| 8146 | sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); | ||
| 8147 | } | ||
| 8148 | |||
| 8149 | yb += 16 * QK4_NL; | ||
| 8150 | } | ||
| 8151 | |||
| 8152 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 8153 | |||
| 8154 | for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { | ||
| 8155 | float sum_all = simd_sum(sumf[row]); | ||
| 8156 | if (tiisg == 0) { | ||
| 8157 | dst_f32[first_row + row] = sum_all; | ||
| 8158 | } | ||
| 8159 | } | ||
| 8160 | } | ||
| 8161 | |||
| 8162 | [[host_name("kernel_mul_mv_iq4_nl_f32")]] | ||
| 8163 | kernel void kernel_mul_mv_iq4_nl_f32( | ||
| 8164 | constant ggml_metal_kargs_mul_mv & args, | ||
| 8165 | device const char * src0, | ||
| 8166 | device const char * src1, | ||
| 8167 | device char * dst, | ||
| 8168 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 8169 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 8170 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 8171 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 8172 | |||
| 8173 | kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||
| 8174 | } | ||
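The inner loop of kernel_mul_mv_iq4_nl_f32_impl above expands eight packed 4-bit indices into the two aux32 words so that each byte viewed through q8 becomes an index into the kvalues_iq4nl_f lookup table cached in threadgroup memory. A host-side C++ sketch of that expansion, for illustration only:

#include <cstdint>
#include <cstring>

// Split four packed bytes (eight 4-bit indices) into eight index bytes:
// out[0..3] receive the low nibbles, out[4..7] the high nibbles.
static inline void expand_iq4_nibbles(uint32_t packed, uint8_t out[8]) {
    const uint32_t lo =  packed       & 0x0f0f0f0f;
    const uint32_t hi = (packed >> 4) & 0x0f0f0f0f;
    std::memcpy(out + 0, &lo, 4);
    std::memcpy(out + 4, &hi, 4);
}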
| 8175 | |||
| 8176 | template<int NR0, typename args_t> | ||
| 8177 | void kernel_mul_mv_iq4_xs_f32_impl( | ||
| 8178 | args_t args, | ||
| 8179 | device const char * src0, | ||
| 8180 | device const char * src1, | ||
| 8181 | device char * dst, | ||
| 8182 | threadgroup char * shmem, | ||
| 8183 | uint3 tgpig, | ||
| 8184 | ushort tiisg, | ||
| 8185 | ushort sgitg) { | ||
| 8186 | const short NSG = FC_mul_mv_nsg; | ||
| 8187 | |||
| 8188 | threadgroup float * shmem_f32 = (threadgroup float *) shmem; | ||
| 8189 | |||
| 8190 | const int r0 = tgpig.x; | ||
| 8191 | const int r1 = tgpig.y; | ||
| 8192 | const int im = tgpig.z; | ||
| 8193 | const int first_row = (r0 * NSG + sgitg) * NR0; | ||
| 8194 | |||
| 8195 | const uint i12 = im%args.ne12; | ||
| 8196 | const uint i13 = im/args.ne12; | ||
| 8197 | |||
| 8198 | const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 8199 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 8200 | |||
| 8201 | device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); | ||
| 8202 | device const float * y = (device const float *) (src1 + offset1); | ||
| 8203 | |||
| 8204 | const int nb = args.ne00/QK_K; | ||
| 8205 | const int ns01 = args.nb01/args.nb00; | ||
| 8206 | |||
| 8207 | const short ix = tiisg/16; // 0 or 1 | ||
| 8208 | const short it = tiisg%16; // 0...15 | ||
| 8209 | const short ib = it/2; | ||
| 8210 | const short il = it%2; | ||
| 8211 | |||
| 8212 | shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; | ||
| 8213 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 8214 | |||
| 8215 | float4 yl[4]; | ||
| 8216 | float sumf[NR0]={0.f}; | ||
| 8217 | |||
| 8218 | device const float * yb = y + ix * QK_K + ib * 32 + il * 8; | ||
| 8219 | |||
| 8220 | uint32_t aux32[2]; | ||
| 8221 | thread const uint8_t * q8 = (thread const uint8_t *)aux32; | ||
| 8222 | |||
| 8223 | float4 qf1, qf2; | ||
| 8224 | |||
| 8225 | // [TAG_MUL_MV_WEIRD] - see the note on the analogous loop condition in kernel_mul_mv_mxfp4_f32_impl below | ||
| 8226 | for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) { | ||
| 8227 | device const float4 * y4 = (device const float4 *)yb; | ||
| 8228 | yl[0] = y4[0]; | ||
| 8229 | yl[1] = y4[4]; | ||
| 8230 | yl[2] = y4[1]; | ||
| 8231 | yl[3] = y4[5]; | ||
| 8232 | |||
| 8233 | for (short row = 0; row < NR0; ++row) { | ||
| 8234 | device const block_iq4_xs & xb = x[row*ns01 + ibl]; | ||
| 8235 | device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); | ||
| 8236 | |||
| 8237 | float4 acc1 = {0.f}, acc2 = {0.f}; | ||
| 8238 | |||
| 8239 | aux32[0] = (q4[0] ) & 0x0f0f0f0f; | ||
| 8240 | aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; | ||
| 8241 | qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; | ||
| 8242 | qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; | ||
| 8243 | acc1 += yl[0] * qf1; | ||
| 8244 | acc2 += yl[1] * qf2; | ||
| 8245 | |||
| 8246 | aux32[0] = (q4[1] ) & 0x0f0f0f0f; | ||
| 8247 | aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; | ||
| 8248 | qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; | ||
| 8249 | qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; | ||
| 8250 | acc1 += yl[2] * qf1; | ||
| 8251 | acc2 += yl[3] * qf2; | ||
| 8252 | |||
| 8253 | acc1 += acc2; | ||
| 8254 | |||
| 8255 | const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; | ||
| 8256 | sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); | ||
| 8257 | } | ||
| 8258 | |||
| 8259 | yb += 2 * QK_K; | ||
| 8260 | } | ||
| 8261 | |||
| 8262 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 8263 | |||
| 8264 | for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { | ||
| 8265 | float sum_all = simd_sum(sumf[row]); | ||
| 8266 | if (tiisg == 0) { | ||
| 8267 | dst_f32[first_row + row] = sum_all; | ||
| 8268 | } | ||
| 8269 | } | ||
| 8270 | } | ||
| 8271 | |||
| 8272 | [[host_name("kernel_mul_mv_iq4_xs_f32")]] | ||
| 8273 | kernel void kernel_mul_mv_iq4_xs_f32( | ||
| 8274 | constant ggml_metal_kargs_mul_mv & args, | ||
| 8275 | device const char * src0, | ||
| 8276 | device const char * src1, | ||
| 8277 | device char * dst, | ||
| 8278 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 8279 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 8280 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 8281 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 8282 | |||
| 8283 | kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||
| 8284 | } | ||
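Each iq4_xs sub-block scale is a 6-bit value split across two fields: the low four bits come from a nibble of scales_l and the high two bits from scales_h, and subtracting 32 re-centers the result around zero. A host-side C++ sketch of the reconstruction of ls above (helper name illustrative):

#include <cstdint>

// Rebuild the signed 6-bit scale of sub-block ib of a block_iq4_xs.
static inline int iq4_xs_subblock_scale(const uint8_t * scales_l, uint16_t scales_h, int ib) {
    const int lo = (scales_l[ib/2] >> (4*(ib%2))) & 0x0f;
    const int hi = (scales_h >> (2*ib)) & 0x03;
    return ((hi << 4) | lo) - 32;
}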
| 8285 | |||
| 8286 | template<int NR0, typename args_t> | ||
| 8287 | void kernel_mul_mv_mxfp4_f32_impl( | ||
| 8288 | args_t args, | ||
| 8289 | device const char * src0, | ||
| 8290 | device const char * src1, | ||
| 8291 | device char * dst, | ||
| 8292 | threadgroup char * shmem, | ||
| 8293 | uint3 tgpig, | ||
| 8294 | ushort tiisg, | ||
| 8295 | ushort sgitg) { | ||
| 8296 | const short NSG = FC_mul_mv_nsg; | ||
| 8297 | |||
| 8298 | threadgroup float * shmem_f32 = (threadgroup float *) shmem; | ||
| 8299 | |||
| 8300 | const int r0 = tgpig.x; | ||
| 8301 | const int r1 = tgpig.y; | ||
| 8302 | const int im = tgpig.z; | ||
| 8303 | |||
| 8304 | const int first_row = (r0 * NSG + sgitg) * NR0; | ||
| 8305 | |||
| 8306 | const uint i12 = im%args.ne12; | ||
| 8307 | const uint i13 = im/args.ne12; | ||
| 8308 | |||
| 8309 | const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 8310 | const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; | ||
| 8311 | |||
| 8312 | device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); | ||
| 8313 | device const float * y = (device const float *) (src1 + offset1); | ||
| 8314 | |||
| 8315 | const int nb = args.ne00/QK_MXFP4; | ||
| 8316 | const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors | ||
| 8317 | |||
| 8318 | const short ix = tiisg/2; // 0...15 | ||
| 8319 | const short it = tiisg%2; // 0 or 1 | ||
| 8320 | |||
| 8321 | shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16]; | ||
| 8322 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 8323 | |||
| 8324 | float4 yl[4]; | ||
| 8325 | float sumf[NR0]={0.f}; | ||
| 8326 | |||
| 8327 | device const float * yb = y + ix*QK_MXFP4 + it*8; | ||
| 8328 | |||
| 8329 | // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster | ||
| 8330 | // no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD] | ||
| 8331 | for (int ib = ix; ib < nb && ib < ns01; ib += 16) { | ||
| 8332 | device const float4 * y4 = (device const float4 *) yb; | ||
| 8333 | |||
| 8334 | yl[0] = y4[0]; | ||
| 8335 | yl[1] = y4[4]; | ||
| 8336 | yl[2] = y4[1]; | ||
| 8337 | yl[3] = y4[5]; | ||
| 8338 | |||
| 8339 | FOR_UNROLL (short row = 0; row < NR0; row++) { | ||
| 8340 | device const block_mxfp4 & xb = x[row*ns01 + ib]; | ||
| 8341 | device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it); | ||
| 8342 | |||
| 8343 | float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]); | ||
| 8344 | float4 acc2 = yl[1]*float4(shmem_f32[q2[0] >> 4 ], shmem_f32[q2[1] >> 4 ], shmem_f32[q2[2] >> 4 ], shmem_f32[q2[3] >> 4 ]); | ||
| 8345 | float4 acc3 = yl[2]*float4(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]); | ||
| 8346 | float4 acc4 = yl[3]*float4(shmem_f32[q2[4] >> 4 ], shmem_f32[q2[5] >> 4 ], shmem_f32[q2[6] >> 4 ], shmem_f32[q2[7] >> 4 ]); | ||
| 8347 | |||
| 8348 | acc1 = (acc1 + acc3) + (acc2 + acc4); | ||
| 8349 | |||
| 8350 | sumf[row] += e8m0_to_fp32(xb.e) * ((acc1[0] + acc1[1]) + (acc1[2] + acc1[3])); | ||
| 8351 | } | ||
| 8352 | |||
| 8353 | yb += 16 * QK_MXFP4; | ||
| 8354 | } | ||
| 8355 | |||
| 8356 | device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||
| 8357 | |||
| 8358 | for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { | ||
| 8359 | float sum_all = simd_sum(sumf[row]); | ||
| 8360 | if (tiisg == 0) { | ||
| 8361 | dst_f32[first_row + row] = sum_all; | ||
| 8362 | } | ||
| 8363 | } | ||
| 8364 | } | ||
| 8365 | |||
| 8366 | [[host_name("kernel_mul_mv_mxfp4_f32")]] | ||
| 8367 | kernel void kernel_mul_mv_mxfp4_f32( | ||
| 8368 | constant ggml_metal_kargs_mul_mv & args, | ||
| 8369 | device const char * src0, | ||
| 8370 | device const char * src1, | ||
| 8371 | device char * dst, | ||
| 8372 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 8373 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 8374 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 8375 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 8376 | |||
| 8377 | kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||
| 8378 | } | ||
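The per-block MXFP4 scale xb.e is an E8M0 value, i.e. a bare 8-bit biased exponent with no sign or mantissa bits. Assuming the usual OCP MX convention (value = 2^(e - 127)), a host-side C++ sketch of the decoding performed by e8m0_to_fp32; the actual helper may treat edge cases (e.g. e == 0 or e == 255) differently:

#include <cmath>
#include <cstdint>

// Decode an E8M0 block scale: an unsigned 8-bit exponent with bias 127.
static inline float e8m0_to_fp32_ref(uint8_t e) {
    return std::ldexp(1.0f, (int) e - 127); // 2^(e - 127)
}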
| 8379 | |||
| 8380 | template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)> | ||
| 8381 | kernel void kernel_get_rows_q( | ||
| 8382 | constant ggml_metal_kargs_get_rows & args, | ||
| 8383 | device const void * src0, | ||
| 8384 | device const void * src1, | ||
| 8385 | device void * dst, | ||
| 8386 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 8387 | ushort tiitg[[thread_index_in_threadgroup]], | ||
| 8388 | ushort3 ntg [[threads_per_threadgroup]]) { | ||
| 8389 | const int32_t iw0 = tgpig.x/args.ne10; | ||
| 8390 | const int32_t i10 = tgpig.x%args.ne10; | ||
| 8391 | const int32_t i11 = tgpig.y; | ||
| 8392 | const int32_t i12 = tgpig.z; | ||
| 8393 | |||
| 8394 | const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; | ||
| 8395 | |||
| 8396 | const int32_t i02 = i11; | ||
| 8397 | const int32_t i03 = i12; | ||
| 8398 | |||
| 8399 | auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); | ||
| 8400 | auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); | ||
| 8401 | |||
| 8402 | for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { | ||
| 8403 | float4x4 temp; | ||
| 8404 | dequantize_func(psrc + ind/nl, ind%nl, temp); | ||
| 8405 | pdst[ind] = temp; | ||
| 8406 | |||
| 8407 | break; | ||
| 8408 | } | ||
| 8409 | } | ||
| 8410 | |||
| 8411 | template<typename T0, typename T> | ||
| 8412 | kernel void kernel_get_rows_f( | ||
| 8413 | constant ggml_metal_kargs_get_rows & args, | ||
| 8414 | device const void * src0, | ||
| 8415 | device const void * src1, | ||
| 8416 | device void * dst, | ||
| 8417 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 8418 | ushort tiitg[[thread_index_in_threadgroup]], | ||
| 8419 | ushort3 ntg [[threads_per_threadgroup]]) { | ||
| 8420 | const int32_t iw0 = tgpig.x/args.ne10; | ||
| 8421 | const int32_t i10 = tgpig.x%args.ne10; | ||
| 8422 | const int32_t i11 = tgpig.y; | ||
| 8423 | const int32_t i12 = tgpig.z; | ||
| 8424 | |||
| 8425 | const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; | ||
| 8426 | |||
| 8427 | const int32_t i02 = i11; | ||
| 8428 | const int32_t i03 = i12; | ||
| 8429 | |||
| 8430 | auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); | ||
| 8431 | auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); | ||
| 8432 | |||
| 8433 | for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { | ||
| 8434 | pdst[ind] = psrc[ind]; | ||
| 8435 | |||
| 8436 | break; | ||
| 8437 | } | ||
| 8438 | } | ||
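Both get_rows variants above implement the same gather: each destination position (i10, i11, i12) reads a row index r from src1 and copies (dequantizing in the quantized case) row r of src0 into the destination row. A sequential host-side C++ sketch of the float case, for reference only; all strides are in bytes, as in the kernel args:

#include <cstdint>
#include <cstring>

// dst[i12][i11][i10][:] = src0[i12][i11][ src1[i12][i11][i10] ][:]
static void get_rows_f32_ref(const char * src0, const char * src1, char * dst,
                             int64_t ne00, int64_t ne10, int64_t ne11, int64_t ne12,
                             uint64_t nb01, uint64_t nb02, uint64_t nb03,
                             uint64_t nb10, uint64_t nb11, uint64_t nb12,
                             uint64_t nb1,  uint64_t nb2,  uint64_t nb3) {
    for (int64_t i12 = 0; i12 < ne12; ++i12) {
        for (int64_t i11 = 0; i11 < ne11; ++i11) {
            for (int64_t i10 = 0; i10 < ne10; ++i10) {
                const int32_t r = *(const int32_t *)(src1 + i12*nb12 + i11*nb11 + i10*nb10);
                const float * src_row = (const float *)(src0 + i12*nb03 + i11*nb02 + r*nb01);
                float       * dst_row = (float       *)(dst  + i12*nb3  + i11*nb2  + i10*nb1);
                std::memcpy(dst_row, src_row, ne00*sizeof(float));
            }
        }
    }
}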
| 8439 | |||
| 8440 | template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)> | ||
| 8441 | kernel void kernel_set_rows_q32( | ||
| 8442 | constant ggml_metal_kargs_set_rows & args, | ||
| 8443 | device const void * src0, | ||
| 8444 | device const void * src1, | ||
| 8445 | device float * dst, | ||
| 8446 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 8447 | uint tiitg[[thread_index_in_threadgroup]], | ||
| 8448 | uint3 tptg [[threads_per_threadgroup]]) { | ||
| 8449 | const int32_t i03 = tgpig.z; | ||
| 8450 | const int32_t i02 = tgpig.y; | ||
| 8451 | |||
| 8452 | const int32_t i12 = i03%args.ne12; | ||
| 8453 | const int32_t i11 = i02%args.ne11; | ||
| 8454 | |||
| 8455 | const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x; | ||
| 8456 | if (i01 >= args.ne01) { | ||
| 8457 | return; | ||
| 8458 | } | ||
| 8459 | |||
| 8460 | const int32_t i10 = i01; | ||
| 8461 | const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; | ||
| 8462 | |||
| 8463 | device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); | ||
| 8464 | const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); | ||
| 8465 | |||
| 8466 | for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) { | ||
| 8467 | quantize_func(src_row + 32*ind, dst_row[ind]); | ||
| 8468 | } | ||
| 8469 | } | ||
| 8470 | |||
| 8471 | template<typename T, typename TI> | ||
| 8472 | kernel void kernel_set_rows_f( | ||
| 8473 | constant ggml_metal_kargs_set_rows & args, | ||
| 8474 | device const void * src0, | ||
| 8475 | device const void * src1, | ||
| 8476 | device float * dst, | ||
| 8477 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 8478 | uint tiitg[[thread_index_in_threadgroup]], | ||
| 8479 | uint3 tptg [[threads_per_threadgroup]]) { | ||
| 8480 | const int32_t i03 = tgpig.z; | ||
| 8481 | const int32_t i02 = tgpig.y; | ||
| 8482 | |||
| 8483 | const int32_t i12 = i03%args.ne12; | ||
| 8484 | const int32_t i11 = i02%args.ne11; | ||
| 8485 | |||
| 8486 | const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x; | ||
| 8487 | if (i01 >= args.ne01) { | ||
| 8488 | return; | ||
| 8489 | } | ||
| 8490 | |||
| 8491 | const int32_t i10 = i01; | ||
| 8492 | const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0]; | ||
| 8493 | |||
| 8494 | device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3); | ||
| 8495 | const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03); | ||
| 8496 | |||
| 8497 | for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) { | ||
| 8498 | dst_row[ind] = (T) src_row[ind]; | ||
| 8499 | } | ||
| 8500 | } | ||
| 8501 | |||
| 8502 | kernel void kernel_diag_f32( | ||
| 8503 | constant ggml_metal_kargs_diag & args, | ||
| 8504 | device const char * src0, | ||
| 8505 | device char * dst, | ||
| 8506 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 8507 | ushort tiitg[[thread_index_in_threadgroup]]) { | ||
| 8508 | constexpr short NW = N_SIMDWIDTH; | ||
| 8509 | |||
| 8510 | const int32_t i3 = tgpig.z; | ||
| 8511 | const int32_t i2 = tgpig.y; | ||
| 8512 | const int32_t i1 = tgpig.x; | ||
| 8513 | |||
| 8514 | device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03); | ||
| 8515 | device float * dst_ptr = (device float *)(dst + i1*args.nb01 + i2*args.nb2 + i3*args.nb3); | ||
| 8516 | |||
| 8517 | for (int i0 = tiitg; i0 < args.ne0; i0 += NW) { | ||
| 8518 | dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f; | ||
| 8519 | } | ||
| 8520 | } | ||
| 8521 | |||
| 8522 | constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; | ||
| 8523 | constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; | ||
| 8524 | |||
| 8525 | // each block_q contains 16*nl weights | ||
| 8526 | template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4> | ||
| 8527 | kernel void kernel_mul_mm( | ||
| 8528 | constant ggml_metal_kargs_mul_mm & args, | ||
| 8529 | device const char * src0, | ||
| 8530 | device const char * src1, | ||
| 8531 | device char * dst, | ||
| 8532 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 8533 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 8534 | ushort tiitg[[thread_index_in_threadgroup]], | ||
| 8535 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 8536 | |||
| 8537 | threadgroup S0 * sa = (threadgroup S0 *)(shmem); | ||
| 8538 | threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); | ||
| 8539 | |||
| 8540 | threadgroup float * sc = (threadgroup float *)(shmem); | ||
| 8541 | |||
| 8542 | constexpr int NR0 = 64; | ||
| 8543 | constexpr int NR1 = 32; | ||
| 8544 | |||
| 8545 | constexpr int NK = 32; | ||
| 8546 | constexpr int NL0 = NK/16; | ||
| 8547 | constexpr int NL1 = NK/8; | ||
| 8548 | |||
| 8549 | const int im = tgpig.z; | ||
| 8550 | const int r0 = tgpig.y*NR0; | ||
| 8551 | const int r1 = tgpig.x*NR1; | ||
| 8552 | |||
| 8553 | // if this block is of 64x32 shape or smaller | ||
| 8554 | const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0; | ||
| 8555 | const short nr1 = (args.ne1 - r1 < NR1) ? (args.ne1 - r1) : NR1; | ||
| 8556 | |||
| 8557 | // a thread shouldn't load data outside of the matrix | ||
| 8558 | const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63 | ||
| 8559 | const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31 | ||
| 8560 | |||
| 8561 | const short il0 = (tiitg % NL0); | ||
| 8562 | |||
| 8563 | short il = il0; | ||
| 8564 | |||
| 8565 | const int i12 = im%args.ne12; | ||
| 8566 | const int i13 = im/args.ne12; | ||
| 8567 | |||
| 8568 | const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||
| 8569 | const short offset1 = il0/nl; | ||
| 8570 | |||
| 8571 | device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1; | ||
| 8572 | |||
| 8573 | const short iy = 8*(tiitg % NL1); | ||
| 8574 | |||
| 8575 | device const T1 * y = (device const T1 *)(src1 | ||
| 8576 | + args.nb13*i13 | ||
| 8577 | + args.nb12*i12 | ||
| 8578 | + args.nb11*(r1 + lr1) | ||
| 8579 | + args.nb10*iy); | ||
| 8580 | |||
| 8581 | #ifndef GGML_METAL_HAS_TENSOR | ||
| 8582 | S0_8x8 ma[4]; | ||
| 8583 | S1_8x8 mb[2]; | ||
| 8584 | |||
| 8585 | simdgroup_float8x8 mc[8]; | ||
| 8586 | |||
| 8587 | for (short i = 0; i < 8; i++){ | ||
| 8588 | mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f); | ||
| 8589 | } | ||
| 8590 | #else | ||
| 8591 | auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0)); | ||
| 8592 | auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK )); | ||
| 8593 | |||
| 8594 | mpp::tensor_ops::matmul2d< | ||
| 8595 | mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), | ||
| 8596 | execution_simdgroups<4>> mm; | ||
| 8597 | |||
| 8598 | auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); | ||
| 8599 | #endif | ||
| 8600 | |||
| 8601 | for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) { | ||
| 8602 | #ifndef GGML_METAL_HAS_TENSOR | ||
| 8603 | // load data and store to threadgroup memory | ||
| 8604 | if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) { | ||
| 8605 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 8606 | |||
| 8607 | // no need for dequantization | ||
| 8608 | for (short i = 0; i < 16; i++) { | ||
| 8609 | const short sx = 2*il0 + i/8; | ||
| 8610 | const short sy = (tiitg/NL0)/8; | ||
| 8611 | |||
| 8612 | //const short lx = i%8; | ||
| 8613 | //const short ly = (tiitg/NL0)%8; | ||
| 8614 | const short lx = (tiitg/NL0)%8; | ||
| 8615 | const short ly = i%8; | ||
| 8616 | |||
| 8617 | const short ib = 8*sx + sy; | ||
| 8618 | |||
| 8619 | *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; | ||
| 8620 | } | ||
| 8621 | } else { | ||
| 8622 | S0_4x4 temp_a; | ||
| 8623 | dequantize_func(x, il, temp_a); | ||
| 8624 | |||
| 8625 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 8626 | |||
| 8627 | FOR_UNROLL (short i = 0; i < 16; i++) { | ||
| 8628 | const short sx = 2*il0 + i/8; | ||
| 8629 | const short sy = (tiitg/NL0)/8; | ||
| 8630 | |||
| 8631 | //const short lx = i%8; | ||
| 8632 | //const short ly = (tiitg/NL0)%8; | ||
| 8633 | const short lx = (tiitg/NL0)%8; | ||
| 8634 | const short ly = i%8; | ||
| 8635 | |||
| 8636 | const short ib = 8*sx + sy; | ||
| 8637 | |||
| 8638 | // NOTE: the array-indexing form below is massively slower than the pointer form; the reason is unclear | ||
| 8639 | //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4]; | ||
| 8640 | |||
| 8641 | *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4]; | ||
| 8642 | } | ||
| 8643 | } | ||
| 8644 | |||
| 8645 | if (FC_mul_mm_bc_inp) { | ||
| 8646 | for (short i = 0; i < 8; ++i) { | ||
| 8647 | const short sx = (tiitg%NL1); | ||
| 8648 | const short sy = (tiitg/NL1)/8; | ||
| 8649 | |||
| 8650 | const short lx = i; | ||
| 8651 | const short ly = (tiitg/NL1)%8; | ||
| 8652 | //const short lx = (tiitg/NL1)%8; | ||
| 8653 | //const short ly = i; | ||
| 8654 | |||
| 8655 | const short ib = 4*sx + sy; | ||
| 8656 | |||
| 8657 | *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; | ||
| 8658 | } | ||
| 8659 | } else { | ||
| 8660 | const short sx = (tiitg%NL1); | ||
| 8661 | const short sy = (tiitg/NL1)/8; | ||
| 8662 | |||
| 8663 | const short dx = sx; | ||
| 8664 | const short dy = sy; | ||
| 8665 | |||
| 8666 | const short ly = (tiitg/NL1)%8; | ||
| 8667 | |||
| 8668 | const short ib = 4*sx + sy; | ||
| 8669 | |||
| 8670 | *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y)); | ||
| 8671 | } | ||
| 8672 | #else | ||
| 8673 | // load data and store to threadgroup memory | ||
| 8674 | if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) { | ||
| 8675 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 8676 | |||
| 8677 | // no need for dequantization | ||
| 8678 | for (short i = 0; i < 16; i++) { | ||
| 8679 | const short sx = 2*il0 + i/8; | ||
| 8680 | const short sy = (tiitg/NL0)/8; | ||
| 8681 | |||
| 8682 | const short lx = i%8; | ||
| 8683 | const short ly = (tiitg/NL0)%8; | ||
| 8684 | //const short lx = (tiitg/NL0)%8; | ||
| 8685 | //const short ly = i%8; | ||
| 8686 | |||
| 8687 | *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; | ||
| 8688 | } | ||
| 8689 | } else { | ||
| 8690 | S0_4x4 temp_a; | ||
| 8691 | dequantize_func(x, il, temp_a); | ||
| 8692 | |||
| 8693 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 8694 | |||
| 8695 | FOR_UNROLL (short i = 0; i < 16; i++) { | ||
| 8696 | const short sx = 2*il0 + i/8; | ||
| 8697 | const short sy = (tiitg/NL0)/8; | ||
| 8698 | |||
| 8699 | const short lx = i%8; | ||
| 8700 | const short ly = (tiitg/NL0)%8; | ||
| 8701 | //const short lx = (tiitg/NL0)%8; | ||
| 8702 | //const short ly = i%8; | ||
| 8703 | |||
| 8704 | *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4]; | ||
| 8705 | } | ||
| 8706 | } | ||
| 8707 | |||
| 8708 | if (FC_mul_mm_bc_inp) { | ||
| 8709 | for (short i = 0; i < 8; ++i) { | ||
| 8710 | const short sx = (tiitg%NL1); | ||
| 8711 | const short sy = (tiitg/NL1)/8; | ||
| 8712 | |||
| 8713 | const short lx = i; | ||
| 8714 | const short ly = (tiitg/NL1)%8; | ||
| 8715 | //const short lx = (tiitg/NL1)%8; | ||
| 8716 | //const short ly = i; | ||
| 8717 | |||
| 8718 | *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; | ||
| 8719 | } | ||
| 8720 | } else { | ||
| 8721 | const short sx = (tiitg%NL1); | ||
| 8722 | const short sy = (tiitg/NL1)/8; | ||
| 8723 | |||
| 8724 | //const short lx = i; | ||
| 8725 | const short ly = (tiitg/NL1)%8; | ||
| 8726 | //const short lx = (tiitg/NL1)%8; | ||
| 8727 | //const short ly = i; | ||
| 8728 | |||
| 8729 | *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y)); | ||
| 8730 | } | ||
| 8731 | #endif | ||
| 8732 | |||
| 8733 | il = (il + 2 < nl) ? il + 2 : il % 2; | ||
| 8734 | x = (il < 2) ? x + (2 + nl - 1)/nl : x; | ||
| 8735 | |||
| 8736 | y += NK; | ||
| 8737 | |||
| 8738 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 8739 | |||
| 8740 | #ifndef GGML_METAL_HAS_TENSOR | ||
| 8741 | // load matrices from threadgroup memory and conduct outer products | ||
| 8742 | threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2)); | ||
| 8743 | threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2)); | ||
| 8744 | |||
| 8745 | FOR_UNROLL (short ik = 0; ik < NK/8; ik++) { | ||
| 8746 | simdgroup_barrier(mem_flags::mem_none); | ||
| 8747 | |||
| 8748 | FOR_UNROLL (short i = 0; i < 4; i++) { | ||
| 8749 | simdgroup_load(ma[i], lsma + 64*i, 8, 0, false); | ||
| 8750 | } | ||
| 8751 | |||
| 8752 | simdgroup_barrier(mem_flags::mem_none); | ||
| 8753 | |||
| 8754 | FOR_UNROLL (short i = 0; i < 2; i++) { | ||
| 8755 | simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false); | ||
| 8756 | } | ||
| 8757 | |||
| 8758 | simdgroup_barrier(mem_flags::mem_none); | ||
| 8759 | |||
| 8760 | FOR_UNROLL (short i = 0; i < 8; i++){ | ||
| 8761 | simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); | ||
| 8762 | } | ||
| 8763 | |||
| 8764 | lsma += 8*64; | ||
| 8765 | lsmb += 4*64; | ||
| 8766 | } | ||
| 8767 | #else | ||
| 8768 | auto sA = tA.slice(0, 0); | ||
| 8769 | auto sB = tB.slice(0, 0); | ||
| 8770 | |||
| 8771 | mm.run(sB, sA, cT); | ||
| 8772 | #endif | ||
| 8773 | } | ||
| 8774 | |||
| 8775 | if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) { | ||
| 8776 | // if no bounds checks on the output are needed, we can directly write to device memory | ||
| 8777 | #ifdef GGML_METAL_HAS_TENSOR | ||
| 8778 | device float * C = (device float *) dst + | ||
| 8779 | r0 + \ | ||
| 8780 | r1 * args.ne0 + im*args.ne1*args.ne0; | ||
| 8781 | |||
| 8782 | auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1)); | ||
| 8783 | cT.store(tC); | ||
| 8784 | #else | ||
| 8785 | device float * C = (device float *) dst + | ||
| 8786 | (r0 + 32*(sgitg & 1)) + \ | ||
| 8787 | (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; | ||
| 8788 | |||
| 8789 | for (short i = 0; i < 8; i++) { | ||
| 8790 | simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false); | ||
| 8791 | } | ||
| 8792 | #endif | ||
| 8793 | } else { | ||
| 8794 | // block is smaller than 64x32, we should avoid writing data outside of the matrix | ||
| 8795 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 8796 | |||
| 8797 | threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0; | ||
| 8798 | |||
| 8799 | #ifdef GGML_METAL_HAS_TENSOR | ||
| 8800 | auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1)); | ||
| 8801 | cT.store(tC); | ||
| 8802 | #else | ||
| 8803 | for (short i = 0; i < 8; i++) { | ||
| 8804 | simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false); | ||
| 8805 | } | ||
| 8806 | #endif | ||
| 8807 | |||
| 8808 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 8809 | |||
| 8810 | if (sgitg == 0) { | ||
| 8811 | for (int j = tiitg; j < nr1; j += NR1) { | ||
| 8812 | device float * D = (device float *) dst + r0 + (r1 + j)*args.ne0 + im*args.ne1*args.ne0; | ||
| 8813 | device float4 * D4 = (device float4 *) D; | ||
| 8814 | |||
| 8815 | threadgroup float * C = temp_str + (j*NR0); | ||
| 8816 | threadgroup float4 * C4 = (threadgroup float4 *) C; | ||
| 8817 | |||
| 8818 | int i = 0; | ||
| 8819 | for (; i < nr0/4; i++) { | ||
| 8820 | *(D4 + i) = *(C4 + i); | ||
| 8821 | } | ||
| 8822 | |||
| 8823 | i *= 4; | ||
| 8824 | for (; i < nr0; i++) { | ||
| 8825 | *(D + i) = *(C + i); | ||
| 8826 | } | ||
| 8827 | } | ||
| 8828 | } | ||
| 8829 | } | ||
| 8830 | } | ||
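The shmem + 4096 offset for sb in kernel_mul_mm follows from the A-tile size: for the half/bfloat instantiations listed further below, the A tile holds NR0 x NK = 64 x 32 two-byte elements (4096 bytes), and the B tile (NK x NR1 = 32 x 32 elements, 2048 bytes) is placed right after it. A small host-side check of that arithmetic, assuming 2-byte staging element types:

#include <cstddef>

constexpr std::size_t kElemSize = 2;                    // sizeof(half), sizeof(bfloat)
constexpr std::size_t kTileA    = 64 * 32 * kElemSize;  // A tile -> sb starts at this offset
constexpr std::size_t kTileB    = 32 * 32 * kElemSize;  // B tile, stored right after A
static_assert(kTileA == 4096, "A tile size must match the sb offset used by the kernel");
static_assert(kTileB == 2048, "B tile size");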
| 8831 | |||
| 8832 | template<short ne20> // n_expert_used | ||
| 8833 | kernel void kernel_mul_mm_id_map0( | ||
| 8834 | constant ggml_metal_kargs_mul_mm_id_map0 & args, | ||
| 8835 | device const char * src2, | ||
| 8836 | device char * htpe, | ||
| 8837 | device char * hids, | ||
| 8838 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 8839 | ushort tpitg[[thread_position_in_threadgroup]], | ||
| 8840 | ushort ntg[[threads_per_threadgroup]]) { | ||
| 8841 | const short ide = tpitg; // expert id | ||
| 8842 | |||
| 8843 | uint32_t n_all = 0; | ||
| 8844 | |||
| 8845 | device int32_t * ids_i32 = (device int32_t *) hids + ide*args.ne21; | ||
| 8846 | |||
| 8847 | for (int i21 = 0; i21 < args.ne21; i21 += ntg) { // n_tokens | ||
| 8848 | if (i21 + tpitg < args.ne21) { | ||
| 8849 | device const int32_t * src2_i32 = (device const int32_t *) (src2 + (i21 + tpitg)*args.nb21); | ||
| 8850 | |||
| 8851 | threadgroup uint16_t * sids = (threadgroup uint16_t *) shmem + tpitg*ne20; | ||
| 8852 | |||
| 8853 | #pragma unroll(ne20) | ||
| 8854 | for (short i20 = 0; i20 < ne20; i20++) { | ||
| 8855 | sids[i20] = src2_i32[i20]; | ||
| 8856 | } | ||
| 8857 | } | ||
| 8858 | |||
| 8859 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 8860 | |||
| 8861 | for (short t = 0; t < ntg; t++) { | ||
| 8862 | if (i21 + t >= args.ne21) { | ||
| 8863 | break; | ||
| 8864 | } | ||
| 8865 | |||
| 8866 | threadgroup const uint16_t * sids = (threadgroup const uint16_t *) shmem + t*ne20; | ||
| 8867 | |||
| 8868 | short sel = 0; | ||
| 8869 | #pragma unroll(ne20) | ||
| 8870 | for (short i20 = 0; i20 < ne20; i20++) { | ||
| 8871 | sel += (sids[i20] == ide)*(i20 + 1); | ||
| 8872 | } | ||
| 8873 | |||
| 8874 | ids_i32[n_all] = (i21 + t)*ne20 + sel - 1; | ||
| 8875 | |||
| 8876 | n_all += sel > 0; | ||
| 8877 | } | ||
| 8878 | |||
| 8879 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 8880 | } | ||
| 8881 | |||
| 8882 | device uint32_t * tpe_u32 = (device uint32_t *) (htpe); | ||
| 8883 | tpe_u32[ide] = n_all; | ||
| 8884 | } | ||
| 8885 | |||
| 8886 | typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t; | ||
| 8887 | |||
| 8888 | template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>; | ||
| 8889 | template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>; | ||
| 8890 | template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>; | ||
| 8891 | template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>; | ||
| 8892 | template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>; | ||
| 8893 | template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; | ||
| 8894 | template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>; | ||
| 8895 | template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>; | ||
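What kernel_mul_mm_id_map0 produces: for every expert it writes the flattened indices token*ne20 + slot of the tokens routed to that expert (one row of capacity ne21 per expert in hids) plus the per-expert token count in htpe. A sequential host-side C++ sketch of the same mapping, for reference only; the names are illustrative:

#include <cstdint>
#include <vector>

// ids: [n_tokens][ne20] selected expert ids per token (row-major)
// out: per-expert list of flattened indices token*ne20 + slot
// cnt: per-expert number of tokens routed to that expert
static void mul_mm_id_map0_ref(const int32_t * ids, int n_tokens, int ne20, int n_expert,
                               std::vector<std::vector<int32_t>> & out,
                               std::vector<uint32_t> & cnt) {
    out.assign(n_expert, {});
    cnt.assign(n_expert, 0);
    for (int t = 0; t < n_tokens; ++t) {
        for (int s = 0; s < ne20; ++s) {
            const int e = ids[t*ne20 + s];
            if (e < 0 || e >= n_expert) continue; // sketch only: skip out-of-range ids
            out[e].push_back(t*ne20 + s);
            cnt[e]++;
        }
    }
}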
| 8896 | |||
| 8897 | template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4> | ||
| 8898 | kernel void kernel_mul_mm_id( | ||
| 8899 | constant ggml_metal_kargs_mul_mm_id & args, | ||
| 8900 | device const char * src0, | ||
| 8901 | device const char * src1, | ||
| 8902 | device const char * htpe, | ||
| 8903 | device const char * hids, | ||
| 8904 | device char * dst, | ||
| 8905 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 8906 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 8907 | ushort tiitg[[thread_index_in_threadgroup]], | ||
| 8908 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 8909 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 8910 | threadgroup S0 * sa = (threadgroup S0 *)(shmem); | ||
| 8911 | threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); | ||
| 8912 | |||
| 8913 | threadgroup float * sc = (threadgroup float *)(shmem); | ||
| 8914 | |||
| 8915 | constexpr int NR0 = 64; | ||
| 8916 | constexpr int NR1 = 32; | ||
| 8917 | |||
| 8918 | constexpr int NK = 32; | ||
| 8919 | constexpr int NL0 = NK/16; | ||
| 8920 | constexpr int NL1 = NK/8; | ||
| 8921 | |||
| 8922 | const int im = tgpig.z; // expert | ||
| 8923 | const int r0 = tgpig.y*NR0; | ||
| 8924 | const int r1 = tgpig.x*NR1; | ||
| 8925 | |||
| 8926 | device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe); | ||
| 8927 | device const int32_t * ids_i32 = (device const int32_t *) (hids); | ||
| 8928 | |||
| 8929 | const int32_t neh1 = tpe_u32[im]; | ||
| 8930 | |||
| 8931 | if (r1 >= neh1) { | ||
| 8932 | return; | ||
| 8933 | } | ||
| 8934 | |||
| 8935 | // if this block is of 64x32 shape or smaller | ||
| 8936 | const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0; | ||
| 8937 | const short nr1 = ( neh1 - r1 < NR1) ? ( neh1 - r1) : NR1; | ||
| 8938 | |||
| 8939 | // a thread shouldn't load data outside of the matrix | ||
| 8940 | const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63 | ||
| 8941 | const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31 | ||
| 8942 | |||
| 8943 | const short il0 = (tiitg % NL0); | ||
| 8944 | |||
| 8945 | short il = il0; | ||
| 8946 | |||
| 8947 | const int id = ids_i32[im*args.ne21 + r1 + lr1]; | ||
| 8948 | |||
| 8949 | const short i11 = (id % args.ne20) % args.ne11; | ||
| 8950 | const short i12 = (id / args.ne20); | ||
| 8951 | const short i13 = 0; | ||
| 8952 | |||
| 8953 | const uint64_t offset0 = im*args.nb02 + i13*args.nb03; | ||
| 8954 | const short offset1 = il0/nl; | ||
| 8955 | |||
| 8956 | device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1; | ||
| 8957 | |||
| 8958 | const short iy = 8*(tiitg % NL1); | ||
| 8959 | |||
| 8960 | device const T1 * y = (device const T1 *)(src1 | ||
| 8961 | + args.nb13*i13 | ||
| 8962 | + args.nb12*i12 | ||
| 8963 | + args.nb11*i11 | ||
| 8964 | + args.nb10*iy); | ||
| 8965 | |||
| 8966 | #ifndef GGML_METAL_HAS_TENSOR | ||
| 8967 | S0_8x8 ma[4]; | ||
| 8968 | S1_8x8 mb[2]; | ||
| 8969 | |||
| 8970 | simdgroup_float8x8 mc[8]; | ||
| 8971 | |||
| 8972 | for (short i = 0; i < 8; i++){ | ||
| 8973 | mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f); | ||
| 8974 | } | ||
| 8975 | #else | ||
| 8976 | auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0)); | ||
| 8977 | auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK )); | ||
| 8978 | |||
| 8979 | mpp::tensor_ops::matmul2d< | ||
| 8980 | mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), | ||
| 8981 | execution_simdgroups<4>> mm; | ||
| 8982 | |||
| 8983 | auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); | ||
| 8984 | #endif | ||
| 8985 | |||
| 8986 | for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) { | ||
| 8987 | #ifndef GGML_METAL_HAS_TENSOR | ||
| 8988 | // load data and store to threadgroup memory | ||
| 8989 | if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) { | ||
| 8990 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 8991 | |||
| 8992 | // no need for dequantization | ||
| 8993 | for (short i = 0; i < 16; i++) { | ||
| 8994 | const short sx = 2*il0 + i/8; | ||
| 8995 | const short sy = (tiitg/NL0)/8; | ||
| 8996 | |||
| 8997 | //const short lx = i%8; | ||
| 8998 | //const short ly = (tiitg/NL0)%8; | ||
| 8999 | const short lx = (tiitg/NL0)%8; | ||
| 9000 | const short ly = i%8; | ||
| 9001 | |||
| 9002 | const short ib = 8*sx + sy; | ||
| 9003 | |||
| 9004 | *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; | ||
| 9005 | } | ||
| 9006 | } else { | ||
| 9007 | S0_4x4 temp_a; | ||
| 9008 | dequantize_func(x, il, temp_a); | ||
| 9009 | |||
| 9010 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 9011 | |||
| 9012 | FOR_UNROLL (short i = 0; i < 16; i++) { | ||
| 9013 | const short sx = 2*il0 + i/8; | ||
| 9014 | const short sy = (tiitg/NL0)/8; | ||
| 9015 | |||
| 9016 | //const short lx = i%8; | ||
| 9017 | //const short ly = (tiitg/NL0)%8; | ||
| 9018 | const short lx = (tiitg/NL0)%8; | ||
| 9019 | const short ly = i%8; | ||
| 9020 | |||
| 9021 | const short ib = 8*sx + sy; | ||
| 9022 | |||
| 9023 | // NOTE: the array-indexing form below is massively slower than the pointer form; the reason is unclear | ||
| 9024 | //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4]; | ||
| 9025 | |||
| 9026 | *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4]; | ||
| 9027 | } | ||
| 9028 | } | ||
| 9029 | |||
| 9030 | if (FC_mul_mm_bc_inp) { | ||
| 9031 | for (short i = 0; i < 8; ++i) { | ||
| 9032 | const short sx = (tiitg%NL1); | ||
| 9033 | const short sy = (tiitg/NL1)/8; | ||
| 9034 | |||
| 9035 | const short lx = i; | ||
| 9036 | const short ly = (tiitg/NL1)%8; | ||
| 9037 | //const short lx = (tiitg/NL1)%8; | ||
| 9038 | //const short ly = i; | ||
| 9039 | |||
| 9040 | const short ib = 4*sx + sy; | ||
| 9041 | |||
| 9042 | *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; | ||
| 9043 | } | ||
| 9044 | } else { | ||
| 9045 | const short sx = (tiitg%NL1); | ||
| 9046 | const short sy = (tiitg/NL1)/8; | ||
| 9047 | |||
| 9048 | const short dx = sx; | ||
| 9049 | const short dy = sy; | ||
| 9050 | |||
| 9051 | const short ly = (tiitg/NL1)%8; | ||
| 9052 | |||
| 9053 | const short ib = 4*sx + sy; | ||
| 9054 | |||
| 9055 | *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y)); | ||
| 9056 | } | ||
| 9057 | #else | ||
| 9058 | // load data and store to threadgroup memory | ||
| 9059 | if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) { | ||
| 9060 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 9061 | |||
| 9062 | // no need for dequantization | ||
| 9063 | for (short i = 0; i < 16; i++) { | ||
| 9064 | const short sx = 2*il0 + i/8; | ||
| 9065 | const short sy = (tiitg/NL0)/8; | ||
| 9066 | |||
| 9067 | const short lx = i%8; | ||
| 9068 | const short ly = (tiitg/NL0)%8; | ||
| 9069 | //const short lx = (tiitg/NL0)%8; | ||
| 9070 | //const short ly = i%8; | ||
| 9071 | |||
| 9072 | *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; | ||
| 9073 | } | ||
| 9074 | } else { | ||
| 9075 | S0_4x4 temp_a; | ||
| 9076 | dequantize_func(x, il, temp_a); | ||
| 9077 | |||
| 9078 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 9079 | |||
| 9080 | FOR_UNROLL (short i = 0; i < 16; i++) { | ||
| 9081 | const short sx = 2*il0 + i/8; | ||
| 9082 | const short sy = (tiitg/NL0)/8; | ||
| 9083 | |||
| 9084 | const short lx = i%8; | ||
| 9085 | const short ly = (tiitg/NL0)%8; | ||
| 9086 | //const short lx = (tiitg/NL0)%8; | ||
| 9087 | //const short ly = i%8; | ||
| 9088 | |||
| 9089 | *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4]; | ||
| 9090 | } | ||
| 9091 | } | ||
| 9092 | |||
| 9093 | if (FC_mul_mm_bc_inp) { | ||
| 9094 | for (short i = 0; i < 8; ++i) { | ||
| 9095 | const short sx = (tiitg%NL1); | ||
| 9096 | const short sy = (tiitg/NL1)/8; | ||
| 9097 | |||
| 9098 | const short lx = i; | ||
| 9099 | const short ly = (tiitg/NL1)%8; | ||
| 9100 | //const short lx = (tiitg/NL1)%8; | ||
| 9101 | //const short ly = i; | ||
| 9102 | |||
| 9103 | *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; | ||
| 9104 | } | ||
| 9105 | } else { | ||
| 9106 | const short sx = (tiitg%NL1); | ||
| 9107 | const short sy = (tiitg/NL1)/8; | ||
| 9108 | |||
| 9109 | //const short lx = i; | ||
| 9110 | const short ly = (tiitg/NL1)%8; | ||
| 9111 | //const short lx = (tiitg/NL1)%8; | ||
| 9112 | //const short ly = i; | ||
| 9113 | |||
| 9114 | *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y)); | ||
| 9115 | } | ||
| 9116 | #endif | ||
| 9117 | |||
| 9118 | il = (il + 2 < nl) ? il + 2 : il % 2; | ||
| 9119 | x = (il < 2) ? x + (2 + nl - 1)/nl : x; | ||
| 9120 | |||
| 9121 | y += NK; | ||
| 9122 | |||
| 9123 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 9124 | |||
| 9125 | #ifndef GGML_METAL_HAS_TENSOR | ||
| 9126 | // load matrices from threadgroup memory and conduct outer products | ||
| 9127 | threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2)); | ||
| 9128 | threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2)); | ||
| 9129 | |||
| 9130 | FOR_UNROLL (short ik = 0; ik < NK/8; ik++) { | ||
| 9131 | simdgroup_barrier(mem_flags::mem_none); | ||
| 9132 | |||
| 9133 | FOR_UNROLL (short i = 0; i < 4; i++) { | ||
| 9134 | simdgroup_load(ma[i], lsma + 64*i, 8, 0, false); | ||
| 9135 | } | ||
| 9136 | |||
| 9137 | simdgroup_barrier(mem_flags::mem_none); | ||
| 9138 | |||
| 9139 | FOR_UNROLL (short i = 0; i < 2; i++) { | ||
| 9140 | simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false); | ||
| 9141 | } | ||
| 9142 | |||
| 9143 | simdgroup_barrier(mem_flags::mem_none); | ||
| 9144 | |||
| 9145 | FOR_UNROLL (short i = 0; i < 8; i++){ | ||
| 9146 | simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); | ||
| 9147 | } | ||
| 9148 | |||
| 9149 | lsma += 8*64; | ||
| 9150 | lsmb += 4*64; | ||
| 9151 | } | ||
| 9152 | #else | ||
| 9153 | auto sA = tA.slice(0, 0); | ||
| 9154 | auto sB = tB.slice(0, 0); | ||
| 9155 | |||
| 9156 | mm.run(sB, sA, cT); | ||
| 9157 | #endif | ||
| 9158 | } | ||
| 9159 | |||
| 9160 | // the destination rows are scattered and the block may be smaller than 64x32, so stage the result in threadgroup memory to avoid writing outside the matrix | ||
| 9161 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 9162 | |||
| 9163 | #ifdef GGML_METAL_HAS_TENSOR | ||
| 9164 | auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1)); | ||
| 9165 | cT.store(tC); | ||
| 9166 | #else | ||
| 9167 | threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0; | ||
| 9168 | |||
| 9169 | for (short i = 0; i < 8; i++) { | ||
| 9170 | simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false); | ||
| 9171 | } | ||
| 9172 | #endif | ||
| 9173 | |||
| 9174 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 9175 | |||
| 9176 | for (short j = sgitg; j < nr1; j += 4) { | ||
| 9177 | const int id = ids_i32[im*args.ne21 + r1 + j]; | ||
| 9178 | |||
| 9179 | const short ide = id % args.ne20; | ||
| 9180 | const short idt = id / args.ne20; | ||
| 9181 | |||
| 9182 | device float * D = (device float *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0; | ||
| 9183 | device float4 * D4 = (device float4 *) D; | ||
| 9184 | |||
| 9185 | threadgroup float * C = (threadgroup float *) shmem + j*NR0; | ||
| 9186 | threadgroup float4 * C4 = (threadgroup float4 *) C; | ||
| 9187 | |||
| 9188 | int i = tiisg; | ||
| 9189 | for (; i < nr0/4; i += 32) { | ||
| 9190 | *(D4 + i) = *(C4 + i); | ||
| 9191 | } | ||
| 9192 | |||
| 9193 | i = (4*(nr0/4)) + tiisg; | ||
| 9194 | for (; i < nr0; i += 32) { | ||
| 9195 | *(D + i) = *(C + i); | ||
| 9196 | } | ||
| 9197 | } | ||
| 9198 | } | ||
| 9199 | |||
| 9200 | #define QK_NL 16 | ||
| 9201 | |||
| 9202 | // | ||
| 9203 | // get rows | ||
| 9204 | // | ||
| 9205 | |||
| 9206 | typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t; | ||
| 9207 | |||
| 9208 | template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float, float>; | ||
| 9209 | template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half, float>; | ||
| 9210 | template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>; | ||
| 9211 | #if defined(GGML_METAL_HAS_BF16) | ||
| 9212 | template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>; | ||
| 9213 | #endif | ||
| 9214 | |||
| 9215 | typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t; | ||
| 9216 | |||
| 9217 | template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>; | ||
| 9218 | template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>; | ||
| 9219 | template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>; | ||
| 9220 | template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>; | ||
| 9221 | template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>; | ||
| 9222 | template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>; | ||
| 9223 | template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>; | ||
| 9224 | template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>; | ||
| 9225 | template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>; | ||
| 9226 | template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_K, QK_NL, dequantize_q5_K>; | ||
| 9227 | template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q6_K, QK_NL, dequantize_q6_K>; | ||
| 9228 | template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; | ||
| 9229 | template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs, QK_NL, dequantize_iq2_xs>; | ||
| 9230 | template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; | ||
| 9231 | template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_s, QK_NL, dequantize_iq3_s>; | ||
| 9232 | template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_s, QK_NL, dequantize_iq2_s>; | ||
| 9233 | template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_s, QK_NL, dequantize_iq1_s>; | ||
| 9234 | template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_q_t kernel_get_rows_q<block_iq1_m, QK_NL, dequantize_iq1_m>; | ||
| 9235 | template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl, 2, dequantize_iq4_nl>; | ||
| 9236 | template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs, QK_NL, dequantize_iq4_xs>; | ||
| 9237 | |||
| 9238 | // | ||
| 9239 | // set rows | ||
| 9240 | // | ||
| 9241 | |||
| 9242 | typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t; | ||
| 9243 | |||
| 9244 | template [[host_name("kernel_set_rows_f32_i64")]] kernel set_rows_f_t kernel_set_rows_f<float, int64_t>; | ||
| 9245 | template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f<float, int32_t>; | ||
| 9246 | template [[host_name("kernel_set_rows_f16_i64")]] kernel set_rows_f_t kernel_set_rows_f<half, int64_t>; | ||
| 9247 | template [[host_name("kernel_set_rows_f16_i32")]] kernel set_rows_f_t kernel_set_rows_f<half, int32_t>; | ||
| 9248 | #if defined(GGML_METAL_HAS_BF16) | ||
| 9249 | template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>; | ||
| 9250 | template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>; | ||
| 9251 | #endif | ||
| 9252 | |||
| 9253 | typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t; | ||
| 9254 | |||
| 9255 | template [[host_name("kernel_set_rows_q8_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>; | ||
| 9256 | template [[host_name("kernel_set_rows_q8_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0, quantize_q8_0>; | ||
| 9257 | template [[host_name("kernel_set_rows_q4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0, quantize_q4_0>; | ||
| 9258 | template [[host_name("kernel_set_rows_q4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0, quantize_q4_0>; | ||
| 9259 | template [[host_name("kernel_set_rows_q4_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1, quantize_q4_1>; | ||
| 9260 | template [[host_name("kernel_set_rows_q4_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1, quantize_q4_1>; | ||
| 9261 | template [[host_name("kernel_set_rows_q5_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0, quantize_q5_0>; | ||
| 9262 | template [[host_name("kernel_set_rows_q5_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0, quantize_q5_0>; | ||
| 9263 | template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1, quantize_q5_1>; | ||
| 9264 | template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1, quantize_q5_1>; | ||
| 9265 | template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>; | ||
| 9266 | template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>; | ||
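The `[[host_name("...")]]` attribute gives each template specialization above a stable entry-point name that the host side can look up in the compiled metallib. Functionally, the set-rows family scatters rows of a source tensor into a destination tensor at row positions taken from an integer index tensor, quantizing on the fly in the `q*_0`/`q*_1`/`iq4_nl` variants. As a rough mental model only (hypothetical names, not part of this commit), the f32/i64 case behaves like this host-side C++ scatter:

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical reference: copy row i of `src` into row `ids[i]` of `dst`.
// The Metal kernels do the same per-row work in parallel and, for the
// quantized variants, also pack each row into a block format while storing.
static void set_rows_f32_ref(const std::vector<float>&   src, // n_rows x n_cols
                             const std::vector<int64_t>& ids, // n_rows destination row indices
                             std::vector<float>&         dst, // at least (max id + 1) x n_cols
                             size_t n_rows, size_t n_cols) {
    for (size_t i = 0; i < n_rows; ++i) {
        std::memcpy(&dst[size_t(ids[i]) * n_cols], &src[i * n_cols], n_cols * sizeof(float));
    }
}
```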
| 9267 | |||
| 9268 | // | ||
| 9269 | // matrix-matrix multiplication | ||
| 9270 | // | ||
| 9271 | |||
| 9272 | typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_t; | ||
| 9273 | |||
| 9274 | template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>; | ||
| 9275 | template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>; | ||
| 9276 | #if defined(GGML_METAL_HAS_BF16) | ||
| 9277 | template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>; | ||
| 9278 | #endif | ||
| 9279 | template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>; | ||
| 9280 | template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>; | ||
| 9281 | template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>; | ||
| 9282 | template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>; | ||
| 9283 | template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>; | ||
| 9284 | template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>; | ||
| 9285 | template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>; | ||
| 9286 | template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>; | ||
| 9287 | template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>; | ||
| 9288 | template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>; | ||
| 9289 | template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>; | ||
| 9290 | template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>; | ||
| 9291 | template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>; | ||
| 9292 | template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>; | ||
| 9293 | template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>; | ||
| 9294 | template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>; | ||
| 9295 | template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>; | ||
| 9296 | template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>; | ||
| 9297 | template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>; | ||
| 9298 | template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>; | ||
| 9299 | |||
| 9300 | template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>; | ||
| 9301 | template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>; | ||
| 9302 | template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>; | ||
| 9303 | template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>; | ||
| 9304 | template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>; | ||
| 9305 | template [[host_name("kernel_mul_mm_q5_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>; | ||
| 9306 | template [[host_name("kernel_mul_mm_q8_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>; | ||
| 9307 | template [[host_name("kernel_mul_mm_mxfp4_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>; | ||
| 9308 | template [[host_name("kernel_mul_mm_q2_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>; | ||
| 9309 | template [[host_name("kernel_mul_mm_q3_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>; | ||
| 9310 | template [[host_name("kernel_mul_mm_q4_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>; | ||
| 9311 | template [[host_name("kernel_mul_mm_q5_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>; | ||
| 9312 | template [[host_name("kernel_mul_mm_q6_K_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>; | ||
| 9313 | template [[host_name("kernel_mul_mm_iq2_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>; | ||
| 9314 | template [[host_name("kernel_mul_mm_iq2_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>; | ||
| 9315 | template [[host_name("kernel_mul_mm_iq3_xxs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>; | ||
| 9316 | template [[host_name("kernel_mul_mm_iq3_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>; | ||
| 9317 | template [[host_name("kernel_mul_mm_iq2_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>; | ||
| 9318 | template [[host_name("kernel_mul_mm_iq1_s_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>; | ||
| 9319 | template [[host_name("kernel_mul_mm_iq1_m_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>; | ||
| 9320 | template [[host_name("kernel_mul_mm_iq4_nl_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>; | ||
| 9321 | template [[host_name("kernel_mul_mm_iq4_xs_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>; | ||
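All `kernel_mul_mm_*` instantiations above share one tiled matrix-multiplication skeleton; what changes per format is essentially the src0 element or block type, the small `nl` step parameter, and the `dequantize_*` callback that expands src0 data into floats before the simdgroup multiplies. The sketch below is a heavily simplified, hedged C++ illustration of that "dequantize a row, then accumulate" structure; the real kernel tiles through threadgroup memory and simdgroup matrices, and the actual `dequantize_*` signatures differ from the hypothetical one used here.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical reference of the structure shared by the kernel_mul_mm_*
// instantiations. `Block` stands in for a quantized block type and
// `dequantize` expands one block into QK floats.
template <typename Block, int QK, void (*dequantize)(const Block*, float*)>
void mul_mm_ref(const Block* src0, const float* src1, float* dst,
                size_t ne00, size_t ne01, size_t ne1) {
    std::vector<float> row(ne00);
    for (size_t i01 = 0; i01 < ne01; ++i01) {
        // expand one quantized row of src0 into floats, block by block
        for (size_t b = 0; b < ne00 / QK; ++b) {
            dequantize(&src0[i01 * (ne00 / QK) + b], &row[b * QK]);
        }
        for (size_t i1 = 0; i1 < ne1; ++i1) {
            float acc = 0.0f;
            for (size_t k = 0; k < ne00; ++k) {
                acc += row[k] * src1[i1 * ne00 + k];
            }
            dst[i1 * ne01 + i01] = acc;
        }
    }
}
```

The point is only that supporting a new quantization format mostly means supplying a new block type plus dequantize routine and stamping out another instantiation.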
| 9322 | |||
| 9323 | // | ||
| 9324 | // indirect matrix-matrix multiplication | ||
| 9325 | // | ||
| 9326 | |||
| 9327 | typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>) mul_mm_id; | ||
| 9328 | |||
| 9329 | template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, float, float2x4>; | ||
| 9330 | template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, float, float2x4>; | ||
| 9331 | #if defined(GGML_METAL_HAS_BF16) | ||
| 9332 | template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>; | ||
| 9333 | #endif | ||
| 9334 | template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>; | ||
| 9335 | template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>; | ||
| 9336 | template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>; | ||
| 9337 | template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, float, float2x4>; | ||
| 9338 | template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, float, float2x4>; | ||
| 9339 | template [[host_name("kernel_mul_mm_id_mxfp4_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, float, float2x4>; | ||
| 9340 | template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, float, float2x4>; | ||
| 9341 | template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, float, float2x4>; | ||
| 9342 | template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, float, float2x4>; | ||
| 9343 | template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, float, float2x4>; | ||
| 9344 | template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, float, float2x4>; | ||
| 9345 | template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, float, float2x4>; | ||
| 9346 | template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, float, float2x4>; | ||
| 9347 | template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, float, float2x4>; | ||
| 9348 | template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, float, float2x4>; | ||
| 9349 | template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, float, float2x4>; | ||
| 9350 | template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, float, float2x4>; | ||
| 9351 | template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, float, float2x4>; | ||
| 9352 | template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, float, float2x4>; | ||
| 9353 | template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, float, float2x4>; | ||
| 9354 | |||
| 9355 | template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>; | ||
| 9356 | template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>; | ||
| 9357 | template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>; | ||
| 9358 | template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>; | ||
| 9359 | template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>; | ||
| 9360 | template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1, float, float4x4, half, half2x4>; | ||
| 9361 | template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0, float, float4x4, half, half2x4>; | ||
| 9362 | template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4, float, float4x4, half, half2x4>; | ||
| 9363 | template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float, float4x4, half, half2x4>; | ||
| 9364 | template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float, float4x4, half, half2x4>; | ||
| 9365 | template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float, float4x4, half, half2x4>; | ||
| 9366 | template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float, float4x4, half, half2x4>; | ||
| 9367 | template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float, float4x4, half, half2x4>; | ||
| 9368 | template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float, float4x4, half, half2x4>; | ||
| 9369 | template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float, float4x4, half, half2x4>; | ||
| 9370 | template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float, float4x4, half, half2x4>; | ||
| 9371 | template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float, float4x4, half, half2x4>; | ||
| 9372 | template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float, float4x4, half, half2x4>; | ||
| 9373 | template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float, float4x4, half, half2x4>; | ||
| 9374 | template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float, float4x4, half, half2x4>; | ||
| 9375 | template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl, float, float4x4, half, half2x4>; | ||
| 9376 | template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float, float4x4, half, half2x4>; | ||
| 9377 | |||
| 9378 | // | ||
| 9379 | // matrix-vector multiplication | ||
| 9380 | // | ||
| 9381 | |||
| 9382 | typedef void (kernel_mul_mv_disp_t)( | ||
| 9383 | ggml_metal_kargs_mul_mv args, | ||
| 9384 | device const char * src0, | ||
| 9385 | device const char * src1, | ||
| 9386 | device char * dst, | ||
| 9387 | uint3 tgpig, | ||
| 9388 | ushort tiisg); | ||
| 9389 | |||
| 9390 | typedef void (kernel_mul_mv2_disp_t)( | ||
| 9391 | ggml_metal_kargs_mul_mv args, | ||
| 9392 | device const char * src0, | ||
| 9393 | device const char * src1, | ||
| 9394 | device char * dst, | ||
| 9395 | threadgroup char * shmem, | ||
| 9396 | uint3 tgpig, | ||
| 9397 | ushort tiisg, | ||
| 9398 | ushort sgitg); | ||
| 9399 | |||
| 9400 | template<kernel_mul_mv_disp_t disp_fn> | ||
| 9401 | void mmv_fn( | ||
| 9402 | ggml_metal_kargs_mul_mv args, | ||
| 9403 | device const char * src0, | ||
| 9404 | device const char * src1, | ||
| 9405 | device char * dst, | ||
| 9406 | threadgroup char * shmem, | ||
| 9407 | uint3 tgpig, | ||
| 9408 | ushort tiitg, | ||
| 9409 | ushort tiisg, | ||
| 9410 | ushort sgitg) { | ||
| 9411 | disp_fn(args, src0, src1, dst, tgpig, tiisg); | ||
| 9412 | } | ||
| 9413 | |||
| 9414 | template<kernel_mul_mv2_disp_t disp_fn> | ||
| 9415 | void mmv_fn( | ||
| 9416 | ggml_metal_kargs_mul_mv args, | ||
| 9417 | device const char * src0, | ||
| 9418 | device const char * src1, | ||
| 9419 | device char * dst, | ||
| 9420 | threadgroup char * shmem, | ||
| 9421 | uint3 tgpig, | ||
| 9422 | ushort tiitg, | ||
| 9423 | ushort tiisg, | ||
| 9424 | ushort sgitg) { | ||
| 9425 | disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||
| 9426 | } | ||
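The two `mmv_fn` overloads above are thin adapters: both expose the same outer signature, and the template's non-type parameter decides whether the wrapped implementation receives the threadgroup memory and simdgroup index or not. The same overload-resolution trick in plain C++ (hypothetical names, only to show the mechanism):

```cpp
#include <cstdio>

// Two "implementation" signatures, one with extra state and one without.
typedef void (impl_simple_t)(const float* x, float* y);
typedef void (impl_shmem_t) (const float* x, float* y, float* scratch);

// Two wrapper overloads with the *same* outer signature; the non-type template
// parameter picks which overload participates, exactly like mmv_fn above.
template <impl_simple_t fn>
void dispatch(const float* x, float* y, float* scratch) { (void) scratch; fn(x, y); }

template <impl_shmem_t fn>
void dispatch(const float* x, float* y, float* scratch) { fn(x, y, scratch); }

static void scale2(const float* x, float* y) { y[0] = 2.0f * x[0]; }

int main() {
    float x = 3.0f, y = 0.0f, scratch = 0.0f;
    dispatch<scale2>(&x, &y, &scratch); // resolves to the impl_simple_t overload
    std::printf("%g\n", y);             // prints 6
}
```

Passing a function whose type matches only one of the dispatch typedefs makes the other template fail substitution, so the correct wrapper is selected automatically.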
| 9427 | |||
| 9428 | typedef decltype(mmv_fn<kernel_mul_mv_t_t_disp<half, half, ggml_metal_kargs_mul_mv>>) mul_mv_disp_fn_t; | ||
| 9429 | |||
| 9430 | template<mul_mv_disp_fn_t disp_fn> | ||
| 9431 | kernel void kernel_mul_mv_id( | ||
| 9432 | constant ggml_metal_kargs_mul_mv_id & args, | ||
| 9433 | device const char * src0s, | ||
| 9434 | device const char * src1, | ||
| 9435 | device char * dst, | ||
| 9436 | device const char * ids, | ||
| 9437 | threadgroup char * shmem [[threadgroup(0)]], | ||
| 9438 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 9439 | ushort tiitg[[thread_index_in_threadgroup]], | ||
| 9440 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 9441 | ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||
| 9442 | const int iid1 = tgpig.z/args.nei0; | ||
| 9443 | const int idx = tgpig.z%args.nei0; | ||
| 9444 | |||
| 9445 | tgpig.z = 0; | ||
| 9446 | |||
| 9447 | const int32_t i02 = ((device const int32_t *) (ids + iid1*args.nbi1))[idx]; | ||
| 9448 | |||
| 9449 | const int64_t i11 = idx % args.ne11; | ||
| 9450 | const int64_t i12 = iid1; | ||
| 9451 | |||
| 9452 | const int64_t i1 = idx; | ||
| 9453 | const int64_t i2 = i12; | ||
| 9454 | |||
| 9455 | device const char * src0_cur = src0s + i02*args.nb02; | ||
| 9456 | device const char * src1_cur = src1 + i11*args.nb11 + i12*args.nb12; | ||
| 9457 | |||
| 9458 | device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float); | ||
| 9459 | |||
| 9460 | ggml_metal_kargs_mul_mv args0 = { | ||
| 9461 | /*.ne00 =*/ args.ne00, | ||
| 9462 | /*.ne01 =*/ args.ne01, | ||
| 9463 | /*.ne02 =*/ 1, // args.ne02, | ||
| 9464 | /*.nb00 =*/ args.nb00, | ||
| 9465 | /*.nb01 =*/ args.nb01, | ||
| 9466 | /*.nb02 =*/ args.nb02, | ||
| 9467 | /*.nb03 =*/ args.nb02, // args.ne02 == 1 | ||
| 9468 | /*.ne10 =*/ args.ne10, | ||
| 9469 | /*.ne11 =*/ 1, // args.ne11, | ||
| 9470 | /*.ne12 =*/ 1, // args.ne12, | ||
| 9471 | /*.nb10 =*/ args.nb10, | ||
| 9472 | /*.nb11 =*/ args.nb11, | ||
| 9473 | /*.nb12 =*/ args.nb12, | ||
| 9474 | /*.nb13 =*/ args.nb12, // ne12 == 1 | ||
| 9475 | /*.ne0 =*/ args.ne0, | ||
| 9476 | /*.ne1 =*/ 1, // args.ne1, | ||
| 9477 | /*.nr0 =*/ args.nr0, | ||
| 9478 | /*.r2 =*/ 1, | ||
| 9479 | /*.r3 =*/ 1, | ||
| 9480 | }; | ||
| 9481 | |||
| 9482 | disp_fn( | ||
| 9483 | args0, | ||
| 9484 | /* src0 */ src0_cur, | ||
| 9485 | /* src1 */ src1_cur, | ||
| 9486 | /* dst */ dst_cur, | ||
| 9487 | shmem, | ||
| 9488 | tgpig, | ||
| 9489 | tiitg, | ||
| 9490 | tiisg, | ||
| 9491 | sgitg); | ||
| 9492 | } | ||
| 9493 | |||
| 9494 | typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>) kernel_mul_mv_id_t; | ||
| 9495 | |||
| 9496 | typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>) kernel_mul_mv_id_4_t; | ||
| 9497 | |||
| 9498 | template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<float, float>>>; | ||
| 9499 | template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<half, float>>>; | ||
| 9500 | #if defined(GGML_METAL_HAS_BF16) | ||
| 9501 | template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_disp<bfloat, float>>>; | ||
| 9502 | #endif | ||
| 9503 | template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<float, float4, float, float4>>>; | ||
| 9504 | template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<half, half4, float, float4>>>; | ||
| 9505 | #if defined(GGML_METAL_HAS_BF16) | ||
| 9506 | template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_disp<bfloat, bfloat4, float, float4>>>; | ||
| 9507 | #endif | ||
| 9508 | |||
| 9509 | template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>; | ||
| 9510 | |||
| 9511 | template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>; | ||
| 9512 | template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>; | ||
| 9513 | template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>; | ||
| 9514 | template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1>>>; | ||
| 9515 | |||
| 9516 | template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4>>>; | ||
| 9517 | |||
| 9518 | template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K>>>; | ||
| 9519 | template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K>>>; | ||
| 9520 | template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K>>>; | ||
| 9521 | template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K>>>; | ||
| 9522 | template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K>>>; | ||
| 9523 | template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S>>>; | ||
| 9524 | template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M>>>; | ||
| 9525 | template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS>>>; | ||
| 9526 | template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS>>>; | ||
| 9527 | template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS>>>; | ||
| 9528 | template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S>>>; | ||
| 9529 | template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S>>>; | ||
| 9530 | template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL>>>; | ||
| 9531 | template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS>>>; | ||
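`kernel_mul_mv_id` is the mixture-of-experts routing shim for all of the instantiations above: `tgpig.z` encodes a (token, expert-slot) pair, the routed expert index `i02` is read from the `ids` tensor, and a reduced single-row `ggml_metal_kargs_mul_mv` is handed to the generic matrix-vector implementation for just that expert's weights. A rough host-side C++ analogue of the routing step (hypothetical names; broadcasting and stride details are simplified away):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical reference: for every (token, expert-slot) pair, look up the
// routed expert in `ids`, then run a plain mat-vec against that expert's matrix.
void mul_mv_id_ref(const std::vector<std::vector<float>>& experts, // [n_expert], each ne01*ne00
                   const std::vector<float>&   src1,               // n_token * ne00, flattened
                   const std::vector<int32_t>& ids,                // n_token * n_used, flattened
                   std::vector<float>&         dst,                // n_token * n_used * ne01
                   int n_token, int n_used, int ne00, int ne01) {
    for (int t = 0; t < n_token; ++t) {
        for (int e = 0; e < n_used; ++e) {
            const int32_t expert = ids[t * n_used + e];      // the i02 lookup in the kernel
            const float* W = experts[expert].data();         // src0_cur
            const float* x = &src1[t * ne00];                // src1_cur
            float*       y = &dst[(t * n_used + e) * ne01];  // dst_cur
            for (int r = 0; r < ne01; ++r) {
                float acc = 0.0f;
                for (int k = 0; k < ne00; ++k) acc += W[r * ne00 + k] * x[k];
                y[r] = acc;
            }
        }
    }
}
```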
| 9532 | |||
| 9533 | kernel void kernel_pool_2d_max_f32( | ||
| 9534 | constant ggml_metal_kargs_pool_2d & args, | ||
| 9535 | device const float * src0, | ||
| 9536 | device float * dst, | ||
| 9537 | uint gid[[thread_position_in_grid]]) { | ||
| 9538 | |||
| 9539 | if (gid >= args.np) { | ||
| 9540 | return; | ||
| 9541 | } | ||
| 9542 | |||
| 9543 | const int idx = gid; | ||
| 9544 | const int I_HW = args.IH * args.IW; | ||
| 9545 | const int O_HW = args.OH * args.OW; | ||
| 9546 | const int nc = idx / O_HW; | ||
| 9547 | const int cur_oh = idx % O_HW / args.OW; | ||
| 9548 | const int cur_ow = idx % O_HW % args.OW; | ||
| 9549 | |||
| 9550 | device const float * i_ptr = src0 + nc * I_HW; | ||
| 9551 | device float * o_ptr = dst + nc * O_HW; | ||
| 9552 | |||
| 9553 | const int start_h = cur_oh * args.s1 - args.p1; | ||
| 9554 | const int bh = MAX(0, start_h); | ||
| 9555 | const int eh = MIN(args.IH, start_h + args.k1); | ||
| 9556 | const int start_w = cur_ow * args.s0 - args.p0; | ||
| 9557 | const int bw = MAX(0, start_w); | ||
| 9558 | const int ew = MIN(args.IW, start_w + args.k0); | ||
| 9559 | |||
| 9560 | float res = -INFINITY; | ||
| 9561 | |||
| 9562 | for (int i = bh; i < eh; i += 1) { | ||
| 9563 | for (int j = bw; j < ew; j += 1) { | ||
| 9564 | res = MAX(res, i_ptr[i * args.IW + j]); | ||
| 9565 | } | ||
| 9566 | } | ||
| 9567 | |||
| 9568 | o_ptr[cur_oh * args.OW + cur_ow] = res; | ||
| 9569 | } | ||
| 9570 | |||
| 9571 | kernel void kernel_pool_2d_avg_f32( | ||
| 9572 | constant ggml_metal_kargs_pool_2d & args, | ||
| 9573 | device const float * src0, | ||
| 9574 | device float * dst, | ||
| 9575 | uint gid[[thread_position_in_grid]]) { | ||
| 9576 | |||
| 9577 | if (gid >= args.np) { | ||
| 9578 | return; | ||
| 9579 | } | ||
| 9580 | |||
| 9581 | const int idx = gid; | ||
| 9582 | const int I_HW = args.IH * args.IW; | ||
| 9583 | const int O_HW = args.OH * args.OW; | ||
| 9584 | const int nc = idx / O_HW; | ||
| 9585 | const int cur_oh = idx % O_HW / args.OW; | ||
| 9586 | const int cur_ow = idx % O_HW % args.OW; | ||
| 9587 | |||
| 9588 | device const float * i_ptr = src0 + nc * I_HW; | ||
| 9589 | device float * o_ptr = dst + nc * O_HW; | ||
| 9590 | |||
| 9591 | const int start_h = cur_oh * args.s1 - args.p1; | ||
| 9592 | const int bh = MAX(0, start_h); | ||
| 9593 | const int eh = MIN(args.IH, start_h + args.k1); | ||
| 9594 | const int start_w = cur_ow * args.s0 - args.p0; | ||
| 9595 | const int bw = MAX(0, start_w); | ||
| 9596 | const int ew = MIN(args.IW, start_w + args.k0); | ||
| 9597 | // const float scale = 1. / ((eh - bh) * (ew - bw)); | ||
| 9598 | const float scale = 1. / (args.k0 * args.k1); | ||
| 9599 | |||
| 9600 | float res = 0; | ||
| 9601 | |||
| 9602 | for (int i = bh; i < eh; i += 1) { | ||
| 9603 | for (int j = bw; j < ew; j += 1) { | ||
| 9604 | float cur = i_ptr[i * args.IW + j]; | ||
| 9605 | res += cur * scale; | ||
| 9606 | } | ||
| 9607 | } | ||
| 9608 | |||
| 9609 | o_ptr[cur_oh * args.OW + cur_ow] = res; | ||
| 9610 | } | ||
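Both 2D pooling kernels map one thread to one output element, clamp the pooling window against the input borders, and scan the clamped window. Note that the average variant divides by the full kernel area `k0*k1`, so positions falling into padding effectively contribute zero; the commented-out line shows the alternative of dividing by the clamped window size. A minimal C++ reference for one output element with the same convention (hypothetical names, not part of this commit):

```cpp
#include <algorithm>
#include <vector>

// Hypothetical reference of one avg-pool output element, mirroring the kernel:
// the window is clamped to the input, but the divisor stays k0*k1, so taps
// that fall into padding contribute zero to the average.
float pool2d_avg_ref(const std::vector<float>& in, int IW, int IH,
                     int ow, int oh, int k0, int k1, int s0, int s1, int p0, int p1) {
    const int bh = std::max(0,  oh * s1 - p1);
    const int eh = std::min(IH, oh * s1 - p1 + k1);
    const int bw = std::max(0,  ow * s0 - p0);
    const int ew = std::min(IW, ow * s0 - p0 + k0);

    float res = 0.0f;
    for (int i = bh; i < eh; ++i) {
        for (int j = bw; j < ew; ++j) {
            res += in[i * IW + j];
        }
    }
    return res / float(k0 * k1);
}
```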
| 9611 | |||
| 9612 | |||
| 9613 | kernel void kernel_pool_1d_max_f32( | ||
| 9614 | constant ggml_metal_kargs_pool_1d & args, | ||
| 9615 | device const float * src, | ||
| 9616 | device float * dst, | ||
| 9617 | uint gid [[thread_position_in_grid]] | ||
| 9618 | ) { | ||
| 9619 | |||
| 9620 | if (gid >= args.np) { | ||
| 9621 | return; | ||
| 9622 | } | ||
| 9623 | |||
| 9624 | const int ow = (int)gid % args.OW; | ||
| 9625 | const int row = (int)gid / args.OW; | ||
| 9626 | |||
| 9627 | const int base = ow * args.s0 - args.p0; | ||
| 9628 | |||
| 9629 | float acc = -INFINITY; | ||
| 9630 | |||
| 9631 | const int src_off = row * args.IW; | ||
| 9632 | const int dst_off = row * args.OW; | ||
| 9633 | |||
| 9634 | for (int ki = 0; ki < args.k0; ++ki) { | ||
| 9635 | int j = base + ki; | ||
| 9636 | if (j < 0 || j >= args.IW){ | ||
| 9637 | continue; | ||
| 9638 | } | ||
| 9639 | float v = src[src_off + j]; | ||
| 9640 | acc = max(acc, v); | ||
| 9641 | } | ||
| 9642 | |||
| 9643 | dst[dst_off + ow] = acc; | ||
| 9644 | } | ||
| 9645 | |||
| 9646 | kernel void kernel_pool_1d_avg_f32( | ||
| 9647 | constant ggml_metal_kargs_pool_1d & args, | ||
| 9648 | device const float * src, | ||
| 9649 | device float * dst, | ||
| 9650 | uint gid [[thread_position_in_grid]] | ||
| 9651 | ) { | ||
| 9652 | |||
| 9653 | if (gid >= args.np) { | ||
| 9654 | return; | ||
| 9655 | } | ||
| 9656 | |||
| 9657 | const int ow = (int)gid % args.OW; | ||
| 9658 | const int row = (int)gid / args.OW; | ||
| 9659 | |||
| 9660 | const int base = ow * args.s0 - args.p0; | ||
| 9661 | |||
| 9662 | float acc = 0.0f; | ||
| 9663 | int cnt = 0; | ||
| 9664 | |||
| 9665 | const int src_off = row * args.IW; | ||
| 9666 | const int dst_off = row * args.OW; | ||
| 9667 | |||
| 9668 | for (int ki = 0; ki < args.k0; ++ki) { | ||
| 9669 | const int j = base + ki; | ||
| 9670 | if (j < 0 || j >= args.IW) { | ||
| 9671 | continue; | ||
| 9672 | } | ||
| 9673 | acc += src[src_off + j]; | ||
| 9674 | cnt += 1; | ||
| 9675 | } | ||
| 9676 | |||
| 9677 | dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f; | ||
| 9678 | } | ||
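The 1D pooling kernels use the same one-thread-per-output layout, but the 1D average divides by `cnt`, the number of in-bounds taps, rather than by the full kernel size as the 2D kernel does, so the two conventions disagree near the borders. A tiny worked example of the difference (assumed parameters, not taken from this commit):

```cpp
#include <cstdio>

// Worked example assuming k0 = 3, s0 = 1, p0 = 1 over the input [1, 2, 3].
// At the left border the window covers {pad, 1, 2}:
//   - count-based divisor (1D kernel):  (1 + 2) / 2 = 1.5
//   - kernel-area divisor (2D kernel):  (1 + 2) / 3 = 1.0
int main() {
    const float in[3] = {1.0f, 2.0f, 3.0f};
    const float sum   = in[0] + in[1]; // only the in-bounds taps
    std::printf("count-based: %g, kernel-area: %g\n", sum / 2.0f, sum / 3.0f);
}
```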
| 9679 | |||
| 9680 | kernel void kernel_opt_step_adamw_f32( | ||
| 9681 | constant ggml_metal_kargs_opt_step_adamw & args, | ||
| 9682 | device float * x, | ||
| 9683 | device const float * g, | ||
| 9684 | device float * g_m, | ||
| 9685 | device float * g_v, | ||
| 9686 | device const float * pars, | ||
| 9687 | uint gid[[thread_position_in_grid]]) { | ||
| 9688 | |||
| 9689 | if (gid >= args.np) { | ||
| 9690 | return; | ||
| 9691 | } | ||
| 9692 | |||
| 9693 | const float alpha = pars[0]; | ||
| 9694 | const float beta1 = pars[1]; | ||
| 9695 | const float beta2 = pars[2]; | ||
| 9696 | const float eps = pars[3]; | ||
| 9697 | const float wd = pars[4]; | ||
| 9698 | const float beta1h = pars[5]; | ||
| 9699 | const float beta2h = pars[6]; | ||
| 9700 | |||
| 9701 | const float gi = g[gid]; | ||
| 9702 | const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1); | ||
| 9703 | const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2); | ||
| 9704 | |||
| 9705 | g_m[gid] = gmi; | ||
| 9706 | g_v[gid] = gvi; | ||
| 9707 | |||
| 9708 | const float mh = gmi * beta1h; | ||
| 9709 | const float vh = sqrt(gvi * beta2h) + eps; | ||
| 9710 | |||
| 9711 | x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh; | ||
| 9712 | } | ||
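`kernel_opt_step_adamw_f32` applies one AdamW update per element, with all hyper-parameters delivered through `pars`. Reading `beta1h` and `beta2h` as host-precomputed bias-correction factors is an assumption (they are not defined in this file; in standard AdamW they would correspond to terms like $1/(1-\beta_1^t)$ and $1/(1-\beta_2^t)$). With that reading, the per-element update implemented above is:

$$
\begin{aligned}
m_t &= \beta_1\, m_{t-1} + (1-\beta_1)\, g_t, \qquad
v_t = \beta_2\, v_{t-1} + (1-\beta_2)\, g_t^2, \\
\hat m_t &= \beta_{1h}\, m_t, \qquad
\hat v_t = \beta_{2h}\, v_t, \\
x_t &= x_{t-1}\,(1 - \alpha\,\mathrm{wd}) \;-\; \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}.
\end{aligned}
$$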
| 9713 | |||
| 9714 | kernel void kernel_opt_step_sgd_f32( | ||
| 9715 | constant ggml_metal_kargs_opt_step_sgd & args, | ||
| 9716 | device float * x, | ||
| 9717 | device const float * g, | ||
| 9718 | device const float * pars, | ||
| 9719 | uint gid[[thread_position_in_grid]]) { | ||
| 9720 | |||
| 9721 | if (gid >= args.np) { | ||
| 9722 | return; | ||
| 9723 | } | ||
| 9724 | |||
| 9725 | x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid]; | ||
| 9726 | } | ||
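`kernel_opt_step_sgd_f32` reads only two values from `pars`; judging purely from the expression, they act as a learning rate and a decoupled weight-decay coefficient (this naming is inferred, not taken from this file):

$$
x \;\leftarrow\; x\,(1 - \alpha\,\mathrm{wd}) \;-\; \alpha\, g,
\qquad \alpha = \texttt{pars[0]},\ \ \mathrm{wd} = \texttt{pars[1]}.
$$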
| 9727 | |||
| 9728 | template<typename T> | ||
| 9729 | kernel void kernel_memset( | ||
| 9730 | constant ggml_metal_kargs_memset & args, | ||
| 9731 | device T * dst, | ||
| 9732 | uint tpig[[thread_position_in_grid]]) { | ||
| 9733 | dst[tpig] = args.val; | ||
| 9734 | } | ||
| 9735 | |||
| 9736 | typedef decltype(kernel_memset<int64_t>) kernel_memset_t; | ||
| 9737 | |||
| 9738 | template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>; | ||
| 9739 | |||
| 9740 | constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]]; | ||
| 9741 | |||
| 9742 | template<typename T> | ||
| 9743 | kernel void kernel_count_equal( | ||
| 9744 | constant ggml_metal_kargs_count_equal & args, | ||
| 9745 | device const char * src0, | ||
| 9746 | device const char * src1, | ||
| 9747 | device atomic_int * dst, | ||
| 9748 | threadgroup int32_t * shmem_i32 [[threadgroup(0)]], | ||
| 9749 | uint3 tgpig[[threadgroup_position_in_grid]], | ||
| 9750 | ushort3 tpitg[[thread_position_in_threadgroup]], | ||
| 9751 | ushort sgitg[[simdgroup_index_in_threadgroup]], | ||
| 9752 | ushort tiisg[[thread_index_in_simdgroup]], | ||
| 9753 | ushort3 ntg[[threads_per_threadgroup]]) { | ||
| 9754 | const short NSG = FC_count_equal_nsg; | ||
| 9755 | |||
| 9756 | const int i3 = tgpig.z; | ||
| 9757 | const int i2 = tgpig.y; | ||
| 9758 | const int i1 = tgpig.x; | ||
| 9759 | |||
| 9760 | if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { | ||
| 9761 | return; | ||
| 9762 | } | ||
| 9763 | |||
| 9764 | int sum = 0; | ||
| 9765 | |||
| 9766 | device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03; | ||
| 9767 | device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13; | ||
| 9768 | |||
| 9769 | for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { | ||
| 9770 | const T v0 = *(device const T *)(base0 + i0*args.nb00); | ||
| 9771 | const T v1 = *(device const T *)(base1 + i0*args.nb10); | ||
| 9772 | sum += (v0 == v1); | ||
| 9773 | } | ||
| 9774 | |||
| 9775 | sum = simd_sum(sum); | ||
| 9776 | |||
| 9777 | if (tiisg == 0) { | ||
| 9778 | shmem_i32[sgitg] = sum; | ||
| 9779 | } | ||
| 9780 | |||
| 9781 | threadgroup_barrier(mem_flags::mem_threadgroup); | ||
| 9782 | |||
| 9783 | if (sgitg == 0) { | ||
| 9784 | float v = 0.0f; | ||
| 9785 | if (tpitg.x < NSG) { | ||
| 9786 | v = shmem_i32[tpitg.x]; | ||
| 9787 | } | ||
| 9788 | |||
| 9789 | float total = simd_sum(v); | ||
| 9790 | if (tpitg.x == 0) { | ||
| 9791 | atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed); | ||
| 9792 | } | ||
| 9793 | } | ||
| 9794 | } | ||
| 9795 | |||
| 9796 | typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t; | ||
| 9797 | |||
| 9798 | template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>; | ||
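`kernel_count_equal` is a two-level reduction: each thread tallies matches over a strided slice of the row, `simd_sum` folds the tallies within each simdgroup, lane 0 of every simdgroup parks its partial result in threadgroup memory, simdgroup 0 reduces those partials, and a single thread publishes the row's count with a relaxed atomic add. Stripped of the parallel machinery, the semantics reduce to this hedged C++ reference (hypothetical names):

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical reference: count positions where two int32 rows are equal and
// accumulate the total into a single counter; the kernel does the same across
// threadgroups via atomic_fetch_add_explicit on the destination.
void count_equal_ref(const int32_t* a, const int32_t* b, size_t n, int32_t* total) {
    int32_t sum = 0;
    for (size_t i = 0; i < n; ++i) {
        sum += (a[i] == b[i]);
    }
    *total += sum; // the kernel uses a relaxed atomic add instead
}
```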
