diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-cuda | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-cuda')
235 files changed, 35073 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt b/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt new file mode 100644 index 0000000..262f882 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt @@ -0,0 +1,259 @@ +cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES + +find_package(CUDAToolkit) + +if (CUDAToolkit_FOUND) + message(STATUS "CUDA Toolkit found") + + if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + # native == GPUs available at build time + # 50 == Maxwell, lowest CUDA 12 standard + # 60 == P100, FP16 CUDA intrinsics + # 61 == Pascal, __dp4a instruction (per-byte integer dot product) + # 70 == V100, FP16 tensor cores + # 75 == Turing, int8 tensor cores + # 80 == Ampere, asynchronous data loading, faster tensor core instructions + # 86 == RTX 3000, needs CUDA v11.1 + # 89 == RTX 4000, needs CUDA v11.8 + # 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores + # + # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run + # XX-real == compile CUDA code as device code for this specific architecture + # no suffix == compile as both PTX and device code + # + # The default behavior for a non-native is to build virtual architectures as needed to cover all features needed + # for best performance and to also build real architectures for the most commonly used GPUs. + if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") + set(CMAKE_CUDA_ARCHITECTURES "native") + else() + if (CUDAToolkit_VERSION VERSION_LESS "13") + list(APPEND CMAKE_CUDA_ARCHITECTURES 50-virtual 61-virtual 70-virtual) + endif () + + list(APPEND CMAKE_CUDA_ARCHITECTURES 75-virtual 80-virtual 86-real) + + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8") + list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real) + endif() + + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8") + # The CUDA architecture 120f-virtual would in principle work for Blackwell support + # but the newly added "f" suffix conflicted with a preexising regex for validating CUDA architectures in CMake. + # So either a recent CMake version or one with the backported fix is needed. + # The following versions should work: + # - CMake >= v3.31.8 && CMake < v4.0.0 + # - CMake >= v4.0.2 + # This is NOT documented in the CMake release notes, + # check Modules/Internal/CMakeCUDAArchitecturesValidate.cmake in the CMake git repository instead. + # However, the architectures 120a-real and 121a-real should work with basically any CMake version and + # until the release of e.g. Rubin there is no benefit to shipping virtual architectures for Blackwell. + list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real) + endif() + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.9") + list(APPEND CMAKE_CUDA_ARCHITECTURES 121a-real) + endif() + endif() + endif() + + enable_language(CUDA) + + # TODO: Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit + if (GGML_CUDA_CUB_3DOT2) + include(FetchContent) + + FetchContent_Declare( + CCCL + GIT_REPOSITORY https://github.com/nvidia/cccl.git + GIT_TAG v3.2.0 + GIT_SHALLOW TRUE + ) + + FetchContent_MakeAvailable(CCCL) + endif() + + # Replace any plain 12X CUDA architectures with their "architecture-specific" equivalents 12Xa. + # 12X is forwards-compatible, 12Xa is not. + # Notably the Blackwell FP4 tensor core instructions are not forwards compatible and therefore need 12Xa. + # But while 12X vs. 12Xa can be checked in device code there is (to my knowledge) no easy way to do the same check in host code. + # So for now just replace all instances of 12X with 12Xa, this should be fine until Rubin is released. + foreach(ARCHS IN ITEMS CMAKE_CUDA_ARCHITECTURES CMAKE_CUDA_ARCHITECTURES_NATIVE) + set(FIXED_ARCHS "") + foreach(ARCH IN LISTS ${ARCHS}) + if (ARCH MATCHES "^12[0-9](-real|-virtual)?$") + string(REGEX REPLACE "^(12[0-9])((-real|-virtual)?)$" "\\1a\\2" FIXED_ARCH ${ARCH}) + message(STATUS "Replacing ${ARCH} in ${ARCHS} with ${FIXED_ARCH}") + list(APPEND FIXED_ARCHS "${FIXED_ARCH}") + else() + list(APPEND FIXED_ARCHS "${ARCH}") + endif() + endforeach() + set(${ARCHS} ${FIXED_ARCHS}) + endforeach() + + # If we try to compile a "native" build it will use the 12X architectures and fail. + # So we should instead use the native architectures as determined by CMake after replacing 12X with 12Xa. + # But if at the time of the build no GPUs are connected at all CMAKE_CUDA_ARCHITECTURES will contain garbage that we should not use. + if (CMAKE_CUDA_ARCHITECTURES STREQUAL "native" AND CMAKE_CUDA_ARCHITECTURES_NATIVE MATCHES "^[0-9]+(a|f)?(-real|-virtual)?(;[0-9]+(a|f)?(-real|-virtual)?|;)*$") + set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES_NATIVE}) + endif() + message(STATUS "Using CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} CMAKE_CUDA_ARCHITECTURES_NATIVE=${CMAKE_CUDA_ARCHITECTURES_NATIVE}") + + file(GLOB GGML_HEADERS_CUDA "*.cuh") + list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h") + + file(GLOB GGML_SOURCES_CUDA "*.cu") + file(GLOB SRCS "template-instances/fattn-tile*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "template-instances/fattn-mma*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "template-instances/mmq*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "template-instances/mmf*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + + if (GGML_CUDA_FA_ALL_QUANTS) + file(GLOB SRCS "template-instances/fattn-vec*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) + else() + file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) + endif() + + ggml_add_backend_library(ggml-cuda + ${GGML_HEADERS_CUDA} + ${GGML_SOURCES_CUDA} + ) + + add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) + + if (GGML_CUDA_GRAPHS) + add_compile_definitions(GGML_CUDA_USE_GRAPHS) + endif() + + if (GGML_CUDA_FORCE_MMQ) + add_compile_definitions(GGML_CUDA_FORCE_MMQ) + endif() + + if (GGML_CUDA_FORCE_CUBLAS) + add_compile_definitions(GGML_CUDA_FORCE_CUBLAS) + endif() + + if (GGML_CUDA_NO_VMM) + add_compile_definitions(GGML_CUDA_NO_VMM) + endif() + + if (NOT GGML_CUDA_FA) + add_compile_definitions(GGML_CUDA_NO_FA) + endif() + + if (GGML_CUDA_NO_PEER_COPY) + add_compile_definitions(GGML_CUDA_NO_PEER_COPY) + endif() + + if (GGML_STATIC) + if (WIN32) + # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas) + else () + if (GGML_CUDA_CUB_3DOT2) + target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL) + endif() + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1") + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) + else() + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static) + endif() + endif() + else() + if (GGML_CUDA_CUB_3DOT2) + target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL) + endif() + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas) + endif() + + if (GGML_CUDA_NO_VMM) + # No VMM requested, no need to link directly with the cuda driver lib (libcuda.so) + else() + target_link_libraries(ggml-cuda PRIVATE CUDA::cuda_driver) + endif() + + set(CUDA_CXX_FLAGS "") + + set(CUDA_FLAGS -use_fast_math -extended-lambda) + + if (GGML_CUDA_DEBUG) + list(APPEND CUDA_FLAGS -lineinfo) + add_compile_definitions(GGML_CUDA_DEBUG) + endif() + + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8") + # Options are: + # - none (not recommended) + # - speed (nvcc's default) + # - balance + # - size + list(APPEND CUDA_FLAGS -compress-mode=${GGML_CUDA_COMPRESSION_MODE}) + endif() + + if (GGML_FATAL_WARNINGS) + list(APPEND CUDA_FLAGS -Werror all-warnings) + endif() + + if (GGML_ALL_WARNINGS AND NOT MSVC) + set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c) + if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL "") + list(APPEND NVCC_CMD -ccbin ${CMAKE_CUDA_HOST_COMPILER}) + endif() + + execute_process( + COMMAND ${NVCC_CMD} -Xcompiler --version + OUTPUT_VARIABLE CUDA_CCFULLVER + ERROR_QUIET + ) + + if (NOT CUDA_CCFULLVER MATCHES clang) + set(CUDA_CCID "GNU") + execute_process( + COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion" + OUTPUT_VARIABLE CUDA_CCVER + ERROR_QUIET + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + else() + if (CUDA_CCFULLVER MATCHES Apple) + set(CUDA_CCID "AppleClang") + else() + set(CUDA_CCID "Clang") + endif() + string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER}) + endif() + + message(STATUS "CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}") + + ggml_get_flags(${CUDA_CCID} ${CUDA_CCVER}) + list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later + endif() + + if (NOT MSVC) + list(APPEND CUDA_CXX_FLAGS -Wno-pedantic) + else() + # CCCL 3.2 onwards will require a cpp-standard-compliant preprocessor for MSVC + # https://github.com/NVIDIA/cccl/pull/6827 + list(APPEND CUDA_CXX_FLAGS /Zc:preprocessor) + endif() + + list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument + + if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL "") + list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED}) + endif() + + target_compile_options(ggml-cuda PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>") +else() + message(FATAL_ERROR "CUDA Toolkit not found") +endif() diff --git a/llama.cpp/ggml/src/ggml-cuda/acc.cu b/llama.cpp/ggml/src/ggml-cuda/acc.cu new file mode 100644 index 0000000..e084607 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/acc.cu @@ -0,0 +1,61 @@ +#include "acc.cuh" + +static __global__ void acc_f32(const float * x, const float * y, float * dst, const int64_t ne, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) { + const int64_t i = blockDim.x * blockIdx.x + threadIdx.x; + + if (i >= ne) { + return; + } + + int64_t src1_idx = i - offset; + + int64_t tmp = src1_idx; + const int64_t i13 = tmp / s13; + tmp -= i13 * s13; + const int64_t i12 = tmp / s12; + tmp -= i12 * s12; + const int64_t i11 = tmp / s11; + tmp -= i11 * s11; + const int64_t i10 = tmp; + + float val = x[i]; + if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) { + val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10]; + } + dst[i] = val; +} + +static void acc_f32_cuda(const float * x, const float * y, float * dst, const int64_t n_elements, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s1, const int64_t s2, const int64_t s3, const int64_t offset, cudaStream_t stream) { + const int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE; + acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset); +} + +void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(dst->nb[0] == ggml_element_size(dst)); + GGML_ASSERT(ggml_is_contiguously_allocated(dst)); + + const int64_t s1 = dst->op_params[0] / sizeof(float); + const int64_t s2 = dst->op_params[1] / sizeof(float); + const int64_t s3 = dst->op_params[2] / sizeof(float); + const int64_t offset = dst->op_params[3] / sizeof(float); + + acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], s1, s2, s3, offset, stream); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/acc.cuh b/llama.cpp/ggml/src/ggml-cuda/acc.cuh new file mode 100644 index 0000000..1168ea1 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/acc.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_ACC_BLOCK_SIZE 256 + +void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/add-id.cu b/llama.cpp/ggml/src/ggml-cuda/add-id.cu new file mode 100644 index 0000000..8d9cf69 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/add-id.cu @@ -0,0 +1,58 @@ +#include "add-id.cuh" + +static __global__ void add_id_kernel( + const float * src0, const float * src1, const int32_t * src2, float * dst, + int64_t ne0, int64_t ne1, + size_t nb01, size_t nb02, + size_t nb11, + size_t nb21 + ) { + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.y; + + const int i11 = *(const int32_t *) ((const char *) src2 + i1*sizeof(int32_t) + i2*nb21); + + const size_t nb1 = ne0 * sizeof(float); + const size_t nb2 = ne1 * nb1; + + float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2); + const float * src0_row = (const float *)((const char *)src0 + i1*nb01 + i2*nb02); + const float * src1_row = (const float *)((const char *)src1 + i11*nb11); + + for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { + dst_row[i0] = src0_row[i0] + src1_row[i0]; + } +} + +void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + GGML_TENSOR_TERNARY_OP_LOCALS + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_I32); + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb20 == sizeof(int32_t)); + + const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + const int32_t * src2_d = (const int32_t *)src2->data; + float * dst_d = (float *)dst->data; + + int threads = std::min((int)ne00, 768); // cols + dim3 blocks(ne01, ne02); // n_experts_used, n_tokens + add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>( + src0_d, src1_d, src2_d, dst_d, + ne0, ne1, + nb01, nb02, + nb11, + nb21 + ); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/add-id.cuh b/llama.cpp/ggml/src/ggml-cuda/add-id.cuh new file mode 100644 index 0000000..30b1721 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/add-id.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/arange.cu b/llama.cpp/ggml/src/ggml-cuda/arange.cu new file mode 100644 index 0000000..b5e495a --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/arange.cu @@ -0,0 +1,34 @@ +#include "arange.cuh" + +static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) { + // blockIDx.x: idx of ne0 / BLOCK_SIZE + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + dst[nidx] = start + step * nidx; +} + +static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) { + int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE; + arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start, step); +} + +void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + float start; + float stop; + float step; + memcpy(&start, (float *)dst->op_params + 0, sizeof(float)); + memcpy(&stop, (float *)dst->op_params + 1, sizeof(float)); + memcpy(&step, (float *)dst->op_params + 2, sizeof(float)); + + int64_t steps = (int64_t)ceil((stop - start) / step); + GGML_ASSERT(ggml_nelements(dst) == steps); + + arange_f32_cuda(dst_d, dst->ne[0], start, step, stream); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/arange.cuh b/llama.cpp/ggml/src/ggml-cuda/arange.cuh new file mode 100644 index 0000000..41e74fd --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/arange.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_ARANGE_BLOCK_SIZE 256 + +void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/argmax.cu b/llama.cpp/ggml/src/ggml-cuda/argmax.cu new file mode 100644 index 0000000..51967c6 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/argmax.cu @@ -0,0 +1,91 @@ +#include <algorithm> +#include <cstdint> + +#include "argmax.cuh" +#include "common.cuh" +#include "sum.cuh" + +static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) { + const int64_t row = blockIdx.x; + + float maxval = -FLT_MAX; + int argmax = -1; + const float * rowx = x + row * ncols; + + for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) { + const float val = rowx[col]; + if (val > maxval) { + maxval = val; + argmax = col; + } + } + +#pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) { + const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); + const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); + if (val > maxval) { + maxval = val; + argmax = col; + } + } + + const int n_warps = blockDim.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + const int warp_id = threadIdx.x / WARP_SIZE; + if (n_warps > 1) { + constexpr int max_warps = 1024 / WARP_SIZE; + __shared__ float shared_maxval[max_warps]; + __shared__ int shared_argmax[max_warps]; + if (lane_id == 0) { + shared_maxval[warp_id] = maxval; + shared_argmax[warp_id] = argmax; + } + + __syncthreads(); + + if (warp_id == 0) { + if (lane_id < n_warps) { + maxval = shared_maxval[lane_id]; + argmax = shared_argmax[lane_id]; + } +#pragma unroll + for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) { + const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE); + const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE); + if (val > maxval) { + maxval = val; + argmax = col; + } + } + } + } + + if (warp_id == 0 && lane_id == 0) { + dst[row] = argmax; + } +} + +void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + const float * src0_d = (const float *) src0->data; + int32_t * dst_d = (int32_t *) dst->data; + + cudaStream_t stream = ctx.stream(); + + const int64_t num_blocks = nrows; + const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE); + const dim3 blocks_dim(num_threads, 1, 1); + const dim3 blocks_num(num_blocks, 1, 1); + + argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/argmax.cuh b/llama.cpp/ggml/src/ggml-cuda/argmax.cuh new file mode 100644 index 0000000..5b7223a --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/argmax.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/argsort.cu b/llama.cpp/ggml/src/ggml-cuda/argsort.cu new file mode 100644 index 0000000..4896669 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/argsort.cu @@ -0,0 +1,230 @@ +#include "argsort.cuh" + +#ifdef GGML_CUDA_USE_CUB +# include <cub/cub.cuh> +# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1) +# define STRIDED_ITERATOR_AVAILABLE +# endif +using namespace cub; +#endif // GGML_CUDA_USE_CUB + +static __global__ void init_indices(int * indices, const int ncols, const int nrows) { + const int col = blockIdx.x * blockDim.x + threadIdx.x; + const int row = blockIdx.y; + + if (col < ncols && row < nrows) { + indices[row * ncols + col] = col; + } +} + +#ifndef STRIDED_ITERATOR_AVAILABLE +static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx <= nrows) { + offsets[idx] = idx * ncols; + } +} +#endif // STRIDED_ITERATOR_AVAILABLE + +#ifdef GGML_CUDA_USE_CUB +void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, + const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream) { + ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows); + ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows); + + int * temp_indices = temp_indices_alloc.get(); + float * temp_keys = temp_keys_alloc.get(); + + static const int block_size = 256; + const dim3 grid_size((ncols + block_size - 1) / block_size, nrows); + init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows); + +#ifdef STRIDED_ITERATOR_AVAILABLE + auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols); +#else + ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1); + int * offset_iterator = offsets_alloc.get(); + const dim3 offset_grid((nrows + block_size - 1) / block_size); + init_offsets<<<offset_grid, block_size, 0, stream>>>(offset_iterator, ncols, nrows); +#endif + CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream)); + + size_t temp_storage_bytes = 0; + + if (order == GGML_SORT_ORDER_ASC) { + if (nrows == 1) { + DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + offset_iterator, offset_iterator + 1, stream); + } + } else { + if (nrows == 1) { + DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, + dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1, + stream); + } + } + + ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes); + void * d_temp_storage = temp_storage_alloc.get(); + + if (order == GGML_SORT_ORDER_ASC) { + if (nrows == 1) { + DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, + ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream); + } + } else { + if (nrows == 1) { + DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, offset_iterator, + offset_iterator + 1, stream); + } + } +} +#endif // GGML_CUDA_USE_CUB + +// Bitonic sort implementation +template<typename T> +static inline __device__ void ggml_cuda_swap(T & a, T & b) { + T tmp = a; + a = b; + b = tmp; +} + +template<ggml_sort_order order> +static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) { + // bitonic sort + int col = threadIdx.x; + int row = blockIdx.x; + + if (col >= ncols_pad) { + return; + } + + const float * x_row = x + row * ncols; + extern __shared__ int dst_row[]; + + // initialize indices + dst_row[col] = col; + + __syncthreads(); + + for (int k = 2; k <= ncols_pad; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { + ggml_cuda_swap(dst_row[col], dst_row[ixj]); + } + } else { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { + ggml_cuda_swap(dst_row[col], dst_row[ixj]); + } + } + } + __syncthreads(); + } + } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } +} + +static int next_power_of_2(int x) { + int n = 1; + while (n < x) { + n *= 2; + } + return n; +} + +void argsort_f32_i32_cuda_bitonic(const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream) { + // bitonic sort requires ncols to be power of 2 + const int ncols_pad = next_power_of_2(ncols); + + const dim3 block_dims(ncols_pad, 1, 1); + const dim3 block_nums(nrows, 1, 1); + const size_t shared_mem = ncols_pad * sizeof(int); + + // FIXME: this limit could be raised by ~2-4x on Ampere or newer + GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); + + if (order == GGML_SORT_ORDER_ASC) { + k_argsort_f32_i32<GGML_SORT_ORDER_ASC> + <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad); + } else if (order == GGML_SORT_ORDER_DESC) { + k_argsort_f32_i32<GGML_SORT_ORDER_DESC> + <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad); + } else { + GGML_ABORT("fatal error"); + } +} + +void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + +#ifdef GGML_CUDA_USE_CUB + const int ncols_pad = next_power_of_2(ncols); + const size_t shared_mem = ncols_pad * sizeof(int); + const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb; + + if (shared_mem > max_shared_mem || ncols > 1024) { + ggml_cuda_pool & pool = ctx.pool(); + argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream); + } else { + argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream); + } +#else + argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream); +#endif +} diff --git a/llama.cpp/ggml/src/ggml-cuda/argsort.cuh b/llama.cpp/ggml/src/ggml-cuda/argsort.cuh new file mode 100644 index 0000000..22b7306 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/argsort.cuh @@ -0,0 +1,19 @@ +#include "common.cuh" + +void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +#ifdef GGML_CUDA_USE_CUB +void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, + const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream); +#endif // GGML_CUDA_USE_CUB +void argsort_f32_i32_cuda_bitonic(const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream); diff --git a/llama.cpp/ggml/src/ggml-cuda/binbcast.cu b/llama.cpp/ggml/src/ggml-cuda/binbcast.cu new file mode 100644 index 0000000..7339fe0 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/binbcast.cu @@ -0,0 +1,504 @@ +#include "binbcast.cuh" +#include <cstdint> +#include <utility> + +static __device__ __forceinline__ float op_repeat(const float a, const float b) { + return b; + GGML_UNUSED(a); +} + +static __device__ __forceinline__ float op_add(const float a, const float b) { + return a + b; +} + +static __device__ __forceinline__ float op_sub(const float a, const float b) { + return a - b; +} + +static __device__ __forceinline__ float op_mul(const float a, const float b) { + return a * b; +} + +static __device__ __forceinline__ float op_div(const float a, const float b) { + return a / b; +} + +template <float (*bin_op)(const float, const float), + typename src0_t, + typename src1_t, + typename dst_t, + typename... src1_ptrs> +static __global__ void k_bin_bcast(const src0_t * src0, + const src1_t * src1, + dst_t * dst, + const int ne0, + const int ne1, + const int ne2, + const uint3 ne3, + const uint3 ne10, + const uint3 ne11, + const uint3 ne12, + const uint3 ne13, + /*const int s0,*/ + const int s1, + const int s2, + const int s3, + const int s00, + const int s01, + const int s02, + const int s03, + const int s10, + const int s11, + const int s12, + const int s13, + src1_ptrs... src1s) { + const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x; + const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y); + const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3); + const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z); + + if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) { + return; + } + + const uint32_t i11 = fastmodulo(i1, ne11); + const uint32_t i12 = fastmodulo(i2, ne12); + const uint32_t i13 = fastmodulo(i3, ne13); + + const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + + const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; + dst_t * dst_row = dst + i_dst; + + for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { + const uint32_t i10 = fastmodulo(i0, ne10); + + float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; + if constexpr (sizeof...(src1_ptrs) > 0) { + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); + } else { + result = bin_op(result, (float)src1[i_src1 + i10*s10]); + } + + dst_row[i0] = (dst_t) result; + } +} + +template <float (*bin_op)(const float, const float), + typename src0_t, + typename src1_t, + typename dst_t, + typename... src1_ptrs> +static __global__ void k_bin_bcast_unravel(const src0_t * src0, + const src1_t * src1, + dst_t * dst, + const uint3 ne0, + const uint3 ne1, + const uint3 ne2, + const uint32_t ne3, + const uint3 prod_012, + const uint3 prod_01, + const uint3 ne10, + const uint3 ne11, + const uint3 ne12, + const uint3 ne13, + /*const int s0,*/ + const int s1, + const int s2, + const int s3, + const int s00, + const int s01, + const int s02, + const int s03, + const int s10, + const int s11, + const int s12, + const int s13, + src1_ptrs... src1s) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + const uint32_t i3 = fastdiv(i, prod_012); + const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01); + const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0); + const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z; + + if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) { + return; + } + + const int i11 = fastmodulo(i1, ne11); + const int i12 = fastmodulo(i2, ne12); + const int i13 = fastmodulo(i3, ne13); + + const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + + const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; + dst_t * dst_row = dst + i_dst; + + const int i10 = fastmodulo(i0, ne10); + + float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; + if constexpr (sizeof...(src1_ptrs) > 0) { + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); + } else { + result = bin_op(result, (float)src1[i_src1 + i10*s10]); + } + + dst_row[i0] = (dst_t) result; +} + +template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, size_t... I> +static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, + cudaStream_t stream, std::index_sequence<I...>) { + GGML_TENSOR_BINARY_OP_LOCALS + + int nr0 = ne10 / ne0; + int nr1 = ne11 / ne1; + int nr2 = ne12 / ne2; + int nr3 = ne13 / ne3; + + int nr[4] = { nr0, nr1, nr2, nr3 }; + + int64_t cne[] = { ne0, ne1, ne2, ne3 }; + int64_t cne0[] = { ne00, ne01, ne02, ne03 }; + int64_t cne1[] = { ne10, ne11, ne12, ne13 }; + + size_t cnb[] = { nb0, nb1, nb2, nb3 }; + size_t cnb0[] = { nb00, nb01, nb02, nb03 }; + size_t cnb1[] = { nb10, nb11, nb12, nb13 }; + + auto collapse = [](int64_t cne[]) { + cne[0] *= cne[1]; + cne[1] = cne[2]; + cne[2] = cne[3]; + cne[3] = 1; + }; + + auto collapse_nb = [](size_t cnb[], const int64_t cne[]) { + cnb[1] *= cne[1]; + cnb[2] *= cne[2]; + cnb[3] *= cne[3]; + }; + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { + for (int i = 0; i < 4; i++) { + if (nr[i] != 1) { + break; + } + if (i > 0) { + collapse_nb(cnb, cne); + collapse_nb(cnb0, cne0); + collapse_nb(cnb1, cne1); + collapse(cne); + collapse(cne0); + collapse(cne1); + } + } + } + + { + int64_t ne0 = cne[0]; + int64_t ne1 = cne[1]; + int64_t ne2 = cne[2]; + int64_t ne3 = cne[3]; + + //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00); + //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01); + //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02); + //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03); + + size_t nb0 = cnb[0]; + size_t nb1 = cnb[1]; + size_t nb2 = cnb[2]; + size_t nb3 = cnb[3]; + + size_t nb00 = cnb0[0]; + size_t nb01 = cnb0[1]; + size_t nb02 = cnb0[2]; + size_t nb03 = cnb0[3]; + + size_t nb10 = cnb1[0]; + size_t nb11 = cnb1[1]; + size_t nb12 = cnb1[2]; + size_t nb13 = cnb1[3]; + + //size_t s0 = nb0 / sizeof(dst_t); + size_t s1 = nb1 / sizeof(dst_t); + size_t s2 = nb2 / sizeof(dst_t); + size_t s3 = nb3 / sizeof(dst_t); + + size_t s10 = nb10 / sizeof(src1_t); + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); + + size_t s00 = nb00 / sizeof(src0_t); + size_t s01 = nb01 / sizeof(src0_t); + size_t s02 = nb02 / sizeof(src0_t); + size_t s03 = nb03 / sizeof(src0_t); + + GGML_ASSERT(nb0 % sizeof(dst_t) == 0); + GGML_ASSERT(nb1 % sizeof(dst_t) == 0); + GGML_ASSERT(nb2 % sizeof(dst_t) == 0); + GGML_ASSERT(nb3 % sizeof(dst_t) == 0); + + GGML_ASSERT(nb00 % sizeof(src0_t) == 0); + GGML_ASSERT(nb01 % sizeof(src0_t) == 0); + GGML_ASSERT(nb02 % sizeof(src0_t) == 0); + GGML_ASSERT(nb03 % sizeof(src0_t) == 0); + + GGML_ASSERT(nb10 % sizeof(src1_t) == 0); + GGML_ASSERT(nb11 % sizeof(src1_t) == 0); + GGML_ASSERT(nb12 % sizeof(src1_t) == 0); + GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + + const int block_size = 128; + + int64_t hne0 = std::max(ne0 / 2LL, 1LL); + + dim3 block_dims; + block_dims.x = std::min<unsigned int>(hne0, block_size); + block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x); + block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U); + + dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y, + (ne2 * ne3 + block_dims.z - 1) / block_dims.z); + + const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]); + const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]); + const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]); + const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]); + + if (block_nums.z > 65535 || block_nums.y > 65535) { + int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; + const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2)); + const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1)); + const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0); + const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1); + const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2); + + if constexpr (sizeof...(I) > 0) { + k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>( + src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, + ne12, ne13, + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); + } else { + k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t> + <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, + ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13); + } + } else { + const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); + if constexpr (sizeof...(I) > 0) { + k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, + /*s0,*/ s1, s2, s3, + s00 ,s01, s02, s03, + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); + } else { + k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>( + src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13); + } + } + } +} + +template <typename T> +static __global__ void k_repeat_back( + const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const size_t s00, const size_t s01, const size_t s02, const size_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) { + + const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x; + const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y; + const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z; + const int64_t tid2 = tid23 % ne2; + const int64_t tid3 = tid23 / ne2; + + if (tid0 >= ne0) { + return; + } + + T sum = 0; + for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) { + for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) { + for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) { + for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) { + sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00]; + } + } + } + } + dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum; +} + +template <float (*bin_op)(const float, const float), int n_fuse = 1> +struct bin_bcast_cuda { + template<typename src0_t, typename src1_t, typename dst_t> + void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, + const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, + cudaStream_t stream) { + launch_bin_bcast_pack<bin_op, src0_t, src1_t, dst_t>( + src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence<n_fuse>{}); + } +}; + +template <typename T> +static void repeat_back_cuda( + const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const size_t s00, const size_t s01, const size_t s02, const size_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { + + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3); + k_repeat_back<T><<<block_nums, block_dims, 0, stream>>> + (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3); +} + +template<class op> +static void ggml_cuda_op_bin_bcast( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) { + + GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); + } +} + +void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat, 0>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream()); +} + +void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); +} + +void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); +} + +void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); +} + +void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); +} + +template <float (*op)(const float, const float), int n_fuse> +static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + cudaStream_t stream = ctx.stream(); + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst, + (const float *) src0->data, (const float *) src1->data, (float *) dst->data, + stream, std::make_index_sequence<n_fuse>{}); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst, + (const half *) src0->data, (const half *) src1->data, (half *) dst->data, + stream, std::make_index_sequence<n_fuse>{}); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst, + (const half *) src0->data, (const float *) src1->data, (half *) dst->data, + stream, std::make_index_sequence<n_fuse>{}); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst, + (const half *) src0->data, (const float *) src1->data, (float *) dst->data, + stream, std::make_index_sequence<n_fuse>{}); + } else { + fprintf(stderr, + "%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n", + __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); + GGML_ABORT("fatal error"); + } +} + + +void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) { + GGML_ASSERT(2 <= n_fuse && n_fuse <= 8); + + switch (n_fuse) { + case 2: + ggml_cuda_op_fused_binbcast_impl<op_add, 2>(ctx, dst); + break; + case 3: + ggml_cuda_op_fused_binbcast_impl<op_add, 3>(ctx, dst); + break; + case 4: + ggml_cuda_op_fused_binbcast_impl<op_add, 4>(ctx, dst); + break; + case 5: + ggml_cuda_op_fused_binbcast_impl<op_add, 5>(ctx, dst); + break; + case 6: + ggml_cuda_op_fused_binbcast_impl<op_add, 6>(ctx, dst); + break; + case 7: + ggml_cuda_op_fused_binbcast_impl<op_add, 7>(ctx, dst); + break; + case 8: + ggml_cuda_op_fused_binbcast_impl<op_add, 8>(ctx, dst); + break; + default: + GGML_ASSERT(false && "Unsupported n_fuse value"); + } +} + +void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->type == dst->type); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_can_repeat(dst, src0)); + + cudaStream_t stream = ctx.stream(); + + GGML_TENSOR_UNARY_OP_LOCALS; + + GGML_ASSERT(ne2*ne3 <= (1 << 15)); + + const size_t ts = ggml_type_size(src0->type); + const size_t s00 = nb00 / ts; + const size_t s01 = nb01 / ts; + const size_t s02 = nb02 / ts; + const size_t s03 = nb03 / ts; + + switch (dst->type) { + case GGML_TYPE_F32: { + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream); + } break; + default: { + GGML_ASSERT(false); + } break; + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh b/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh new file mode 100644 index 0000000..62bc950 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh @@ -0,0 +1,11 @@ +#include "common.cuh" + +void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse); diff --git a/llama.cpp/ggml/src/ggml-cuda/clamp.cu b/llama.cpp/ggml/src/ggml-cuda/clamp.cu new file mode 100644 index 0000000..fe415e7 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/clamp.cu @@ -0,0 +1,45 @@ +#include "clamp.cuh" + +static __device__ __forceinline__ float op_clamp(float x, float min, float max) { + return fminf(fmaxf(x, min), max); +} + +template <class T> +static __global__ void op_clamp_kernel(const T * x, T * dst, const T min, const T max, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = (T)op_clamp((float)x[i], (float)min, (float)max); +} + +template <class T> +static void clamp_cuda(const T * x, T * dst, const T min, const T max, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE; + op_clamp_kernel<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k); +} + + +void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const void * src0_d = src0->data; + void * dst_d = dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + + float min; + float max; + memcpy(&min, dst->op_params, sizeof(float)); + memcpy(&max, (float *) dst->op_params + 1, sizeof(float)); + + if (src0->type == GGML_TYPE_F16) { + clamp_cuda((const half *)src0_d, (half *)dst_d, (half)min, (half)max, ggml_nelements(src0), stream); + } else { + clamp_cuda((const float *)src0_d, (float *)dst_d, (float)min, (float)max, ggml_nelements(src0), stream); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/clamp.cuh b/llama.cpp/ggml/src/ggml-cuda/clamp.cuh new file mode 100644 index 0000000..7f9559d --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/clamp.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_CLAMP_BLOCK_SIZE 256 + +void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/common.cuh b/llama.cpp/ggml/src/ggml-cuda/common.cuh new file mode 100644 index 0000000..a3256d5 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/common.cuh @@ -0,0 +1,1440 @@ +#pragma once + +#include "ggml.h" +#include "ggml-impl.h" +#include "ggml-cuda.h" + +#include <cstdint> +#include <memory> + +#if defined(GGML_USE_HIP) +#define GGML_COMMON_DECL_HIP +#define GGML_COMMON_IMPL_HIP +#else +#define GGML_COMMON_DECL_CUDA +#define GGML_COMMON_IMPL_CUDA +#if defined(GGML_USE_MUSA) +#define GGML_COMMON_DECL_MUSA +#define GGML_COMMON_IMPL_MUSA +#endif +#endif +#include "ggml-common.h" + +#include <array> +#include <algorithm> +#include <cassert> +#include <cfloat> +#include <cstdio> +#include <string> +#include <unordered_map> +#include <vector> + +#if defined(GGML_USE_HIP) +#include "vendors/hip.h" +#elif defined(GGML_USE_MUSA) +#include "vendors/musa.h" +#else +#include "vendors/cuda.h" +#endif // defined(GGML_USE_HIP) + +#define STRINGIZE_IMPL(...) #__VA_ARGS__ +#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__) + +#define WARP_SIZE 32 +#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed) +#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons + +#define GGML_CUDA_CC_PASCAL 600 +#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products +#define GGML_CUDA_CC_VOLTA 700 +#define GGML_CUDA_CC_TURING 750 +#define GGML_CUDA_CC_AMPERE 800 +#define GGML_CUDA_CC_ADA_LOVELACE 890 +// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see +// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms +#define GGML_CUDA_CC_BLACKWELL 1200 +#define GGML_CUDA_CC_DGX_SPARK 1210 +#define GGML_CUDA_CC_RUBIN 1300 +#define GGML_CUDA_CC_OFFSET_AMD 0x1000000 +#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000 +#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS) + +// AMD +// GCN/CDNA, wave size is 64 +#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 +#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue +#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a +#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers +#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing +#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 + +// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32 +#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 +#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a +#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA +#define GGML_CUDA_CC_RDNA3_5 (GGML_CUDA_CC_OFFSET_AMD + 0x1150) // AI 370, AI Max 395 laptops. +#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000 + +#define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD) +#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2) +#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) +#define GGML_CUDA_CC_IS_RDNA3_0(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA3_5) +#define GGML_CUDA_CC_IS_RDNA3_5(cc) (cc >= GGML_CUDA_CC_RDNA3_5 && cc < GGML_CUDA_CC_RDNA4) +#define GGML_CUDA_CC_IS_RDNA3(cc) (GGML_CUDA_CC_IS_RDNA3_0(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc)) +#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4) +#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1) +#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2) +#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3) +#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) + +// Moore Threads +#define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons + +#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 +#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000 +#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000 + +#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD) +#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2) +#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1) +#define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1) + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 +# define GGML_CUDA_USE_CUB +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 + +#ifdef __CUDA_ARCH_LIST__ +constexpr bool ggml_cuda_has_arch_impl(int) { + return false; +} + +template<class ... Archs> +constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) { + return arch == first || ggml_cuda_has_arch_impl(arch, rest...); +} + +constexpr bool ggml_cuda_has_arch(const int arch) { + return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__); +} + +constexpr int ggml_cuda_highest_compiled_arch_impl(const int /*arch*/, const int cur) { + if (cur == 0) { + return -1; + } + return cur; +} + +template<class ... Archs> +constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) { + if (first <= arch && first > cur) { + return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...); + } else { + return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...); + } +} + +constexpr int ggml_cuda_highest_compiled_arch(const int arch) { + return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__); +} +#else +static int ggml_cuda_highest_compiled_arch(const int arch) { + return arch; +} +#endif // __CUDA_ARCH_LIST__ + +// --------------------------------------------------------------------------------------------------------- + +#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses + +#define GGML_CUDA_MAX_STREAMS 8 + +[[noreturn]] +void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg); + +#define CUDA_CHECK_GEN(err, success, error_fn) \ + do { \ + auto err_ = (err); \ + if (err_ != (success)) { \ + ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_)); \ + } \ + } while (0) + +#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString) + +#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA) + static const char * cublas_get_error_str(const cublasStatus_t err) { + return cublasGetStatusString(err); + } +#else + static const char * cublas_get_error_str(const cublasStatus_t err) { + switch (err) { + case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; + default: return "unknown error"; + } + } +#endif // CUDART_VERSION >= 12000 + +#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str) + +#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) +static const char * cu_get_error_str(CUresult err) { + const char * err_str; + cuGetErrorString(err, &err_str); + return err_str; +} +#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str) +#endif + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \ + do { \ + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \ + const int id = ggml_cuda_get_device(); \ + if (!shared_memory_limit_raised[id]) { \ + CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \ + shared_memory_limit_raised[id] = true; \ + } \ + } while (0) +#else +# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \ + do { \ + GGML_UNUSED(nbytes); \ + } while (0) +#endif // !(defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + +#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) +#define GGML_CUDA_ASSUME(x) __builtin_assume(x) +#else +#define GGML_CUDA_ASSUME(x) +#endif // CUDART_VERSION >= 11010 + +#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) +#define GGML_USE_VMM +#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) + +#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL +#define FP16_AVAILABLE +#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL + +#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 +#define FAST_FP16_AVAILABLE +#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 + +#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) +#define AMD_MFMA_AVAILABLE +#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) + +#if defined(GGML_USE_HIP) && (defined(RDNA4) || defined(RDNA3)) +#define AMD_WMMA_AVAILABLE +#endif // defined(GGML_USE_HIP) && defined(RDNA4) + +// The Volta instructions are in principle available on Turing or newer but they are effectively unusable: +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#define VOLTA_MMA_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING +#define TURING_MMA_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING + +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#define AMPERE_MMA_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN +# define BLACKWELL_MMA_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL + +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#define CP_ASYNC_AVAILABLE +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + +#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220) +#define FLASH_ATTN_AVAILABLE +#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220) + +#if defined(TURING_MMA_AVAILABLE) +#define LDMATRIX_TRANS_AVAILABLE +#endif // defined(TURING_MMA_AVAILABLE) + +static bool fp16_available(const int cc) { + return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL || + (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); +} + +static bool fast_fp16_available(const int cc) { + return GGML_CUDA_CC_IS_AMD(cc) || + (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) || + (GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc)); +} + +// To be used for feature selection of external libraries, e.g. cuBLAS. +static bool fast_fp16_hardware_available(const int cc) { + return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc) || + (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2); +} + +// To be used for feature selection of external libraries, e.g. cuBLAS. +static bool fp16_mma_hardware_available(const int cc) { + return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || + GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) || + (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2); +} + +static bool bf16_mma_hardware_available(const int cc) { + return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || + GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 || + (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); +} + +static bool fp32_mma_hardware_available(const int cc) { + return GGML_CUDA_CC_IS_CDNA(cc); +} + +static bool amd_mfma_available(const int cc) { +#if !defined(GGML_HIP_NO_MMQ_MFMA) + return GGML_CUDA_CC_IS_CDNA(cc); +#else + return false; +#endif //!defined(GGML_HIP_NO_MMQ_MFMA) +} + +static bool amd_wmma_available(const int cc) { + return (GGML_CUDA_CC_IS_RDNA4(cc) || GGML_CUDA_CC_IS_RDNA3(cc)); +} + +static bool volta_mma_available(const int cc) { + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA; +} + +static bool turing_mma_available(const int cc) { + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING; +} + +static bool ampere_mma_available(const int cc) { + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE; +} + +static bool cp_async_available(const int cc) { + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE; +} + +static bool blackwell_mma_available(const int cc) { + return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL && + ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN; +} + +static constexpr __device__ int ggml_cuda_get_physical_warp_size() { +#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) + return 64; +#else + return 32; +#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__)) +} + +// Maximum number of bytes that can be copied in a single instruction. +static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() { +#ifdef GGML_USE_HIP + return 16; +#else +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + return 16; +#else + return 8; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // GGML_USE_HIP +} + + +[[noreturn]] +static __device__ void no_device_code( + const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) { + +#if defined(GGML_USE_HIP) + printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n", + file_name, line, function_name, arch); + GGML_UNUSED(arch_list); +#else + printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n", + file_name, line, function_name, arch, arch_list); +#endif // defined(GGML_USE_HIP) + __trap(); + + GGML_UNUSED(no_device_code); // suppress unused function warning + +#if defined(GGML_USE_MUSA) + __builtin_unreachable(); +#endif // defined(GGML_USE_MUSA) +} + +#ifdef __CUDA_ARCH__ +#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__)) +#else +#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.") +#endif // __CUDA_ARCH__ + +// The compiler is always able to unroll loops if they contain continue expressions. +// In such cases loop unrolling can still be achieved via recursion: +template <int n> +struct ggml_cuda_unroll { + template <typename Func, typename... Args> + __device__ void operator()(const Func & f, Args... args) const { + f(n - 1, args...); + ggml_cuda_unroll<n - 1>{}(f, args...); + } +}; + +template <> +struct ggml_cuda_unroll<1> { + template <typename Func, typename... Args> + __device__ void operator()(const Func & f, Args... args) const { + f(0, args...); + } +}; + +template<int width = WARP_SIZE> +static __device__ __forceinline__ int warp_reduce_sum(int x) { +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + return __reduce_add_sync(0xffffffff, x); +#else +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, offset, width); + } + return x; +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +} + +template<int width = WARP_SIZE> +static __device__ __forceinline__ float warp_reduce_sum(float x) { +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, offset, width); + } + return x; +} + +template<int width = WARP_SIZE> +static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width); + a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width); + } + return a; +} + +template<int width = WARP_SIZE> +static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { +#ifdef FP16_AVAILABLE +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width)); + } + return a; + +#else + NO_DEVICE_CODE; + return a; +#endif // FP16_AVAILABLE +} + +template<int width = WARP_SIZE> +static __device__ __forceinline__ int warp_reduce_all(int x) { + if (width == ggml_cuda_get_physical_warp_size()) { + return __all_sync(0xffffffff, x); + } else { +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x = __shfl_xor_sync(0xffffffff, x, offset, width) && x; + } + return x; + } +} + +template<int width = WARP_SIZE> +static __device__ __forceinline__ int warp_reduce_any(int x) { + if (width == ggml_cuda_get_physical_warp_size()) { + return __any_sync(0xffffffff, x); + } else { +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x = __shfl_xor_sync(0xffffffff, x, offset, width) || x; + } + return x; + } +} + +template<int width = WARP_SIZE> +static __device__ __forceinline__ float warp_reduce_max(float x) { +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width)); + } + return x; +} + +template<typename T, int width = WARP_SIZE> +static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) { + const int lane_id = threadIdx.x % width; +#pragma unroll + for (int offset = 1; offset < width; offset <<= 1) { + const T t = __shfl_up_sync(0xffffffff, x, offset, width); + if (lane_id >= offset) { + x += t; + } + } + return x; +} + +template<int width = WARP_SIZE> +static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) { + const int lane_id = threadIdx.x % width; +#pragma unroll + for (int offset = 1; offset < width; offset <<= 1) { + const float t_x = __shfl_up_sync(0xffffffff, a.x, offset, width); + const float t_y = __shfl_up_sync(0xffffffff, a.y, offset, width); + if (lane_id >= offset) { + a.x += t_x; + a.y += t_y; + } + } + return a; +} + +template<int width = WARP_SIZE> +static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) { +#ifdef FP16_AVAILABLE + const int lane_id = threadIdx.x % width; +#pragma unroll + for (int offset = 1; offset < width; offset <<= 1) { + const half2 t = __shfl_up_sync(0xffffffff, a, offset, width); + if (lane_id >= offset) { + a = __hadd2(a, t); + } + } + return a; + +#else + NO_DEVICE_CODE; + return a; +#endif // FP16_AVAILABLE +} + +enum class block_reduce_method { + MAX, + SUM, +}; + +template<block_reduce_method method_t, typename T> +struct block_reduce_policy; + +template <typename T, typename... Ts> +inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...); + +template<typename...> +inline constexpr bool ggml_cuda_dependent_false_v = false; + +template <typename T> struct block_reduce_policy<block_reduce_method::SUM, T> { + static __device__ T reduce(T val) { + if constexpr(is_any<T, float, float2, half2, int>) { + return warp_reduce_sum(val); + } else { + static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum"); + } + } + + static __device__ T sentinel() { + if constexpr (std::is_same_v<T, float>) { + return 0.0f; + } else if constexpr (std::is_same_v<T, float2>) { + return make_float2(0.0f, 0.0f); + } else if constexpr (std::is_same_v<T, half2>) { + return make_half2(0.0f, 0.0f); + } else if constexpr (std::is_same_v<T, int>) { + return 0; + } else { + static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum"); + } + } +}; + +template <typename T> struct block_reduce_policy<block_reduce_method::MAX, T> { + static __device__ T reduce(T val) { + if constexpr (is_any<T, float, half2>) { + return warp_reduce_max(val); + } else { + static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max"); + } + } + + static __device__ T sentinel() { + if constexpr (std::is_same_v<T, float>) { + return -INFINITY; + } else if constexpr (std::is_same_v<T, half2>) { + return make_half2(-INFINITY, -INFINITY); + } else { + static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max"); + } + } +}; + +template <block_reduce_method reduce_method_t, const unsigned int block_size_template = 0, typename T> +static __device__ T block_reduce(T val, T * shared_vals) { + val = block_reduce_policy<reduce_method_t, T>::reduce(val); + const unsigned int block_size = block_size_template == 0 ? blockDim.x : block_size_template; + if (block_size > WARP_SIZE) { + assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0); + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + shared_vals[warp_id] = val; + } + __syncthreads(); + val = block_reduce_policy<reduce_method_t, T>::sentinel(); + if (lane_id < (static_cast<int>(block_size) / WARP_SIZE)) { + val = shared_vals[lane_id]; + } + return block_reduce_policy<reduce_method_t, T>::reduce(val); + } + + return val; +} + +static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) { +#ifdef FP16_AVAILABLE + +#if !defined(GGML_USE_HIP) && CUDART_VERSION < CUDART_HMAX + return __float2half(fmaxf(__half2float(a), __half2float(b))); +#else + return __hmax(a, b); +#endif // !defined(GGML_USE_HIP) && CUDART_VERSION < CUDART_HMAX + +#else + NO_DEVICE_CODE; + GGML_UNUSED(b); + return a; +#endif // FP16_AVAILABLE +} + +static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) { +#if defined(GGML_USE_HIP) + return half2(__hmax(a.x, b.x), __hmax(a.y, b.y)); +#elif CUDART_VERSION >= CUDART_HMAX + return __hmax2(a, b); +#else + half2 ret; + reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a), __low2float(b))); + reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b))); + return ret; +#endif +} + +template<int width = WARP_SIZE> +static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { +#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP) +#pragma unroll + for (int offset = width/2; offset > 0; offset >>= 1) { + x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width)); + } + return x; +#else + GGML_UNUSED(x); + NO_DEVICE_CODE; +#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP) +} + +#if (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || \ + (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK) +static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) { + const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b))); + const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b))); + return mask_low | mask_high; +} +#endif // (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK) + +static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) { +#if defined(GGML_USE_HIP) +#if defined(CDNA) || defined(RDNA2) || defined(__gfx906__) + c = __builtin_amdgcn_sdot4(a, b, c, false); +#elif defined(RDNA3) || defined(RDNA4) + c = __builtin_amdgcn_sudot4( true, a, true, b, c, false); +#elif defined(RDNA1) || defined(__gfx900__) + int tmp1; + int tmp2; + asm("\n \ + v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \ + v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \ + v_add3_u32 %0, %1, %2, %0 \n \ + v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \ + v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \ + v_add3_u32 %0, %1, %2, %0 \n \ + " + : "+v"(c), "=&v"(tmp1), "=&v"(tmp2) + : "v"(a), "v"(b) + ); +#else + const int8x4_t va = reinterpret_cast<const int8x4_t&>(a); + const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b); + c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3]; +#endif + return c; + +#else // defined(GGML_USE_HIP) + +#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA) + return __dp4a(a, b, c); +#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA) + const int8_t * a8 = (const int8_t *) &a; + const int8_t * b8 = (const int8_t *) &b; + return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3]; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA) + +#endif // defined(GGML_USE_HIP) +} + +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) { + acc += v*u; +} + +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) { + acc += v.x*u.x; + acc += v.y*u.y; +} + +#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA)) +#define V_DOT2_F32_F16_AVAILABLE +#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA)) + +static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) { +#ifdef V_DOT2_F32_F16_AVAILABLE + asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u)); +#else +#ifdef FAST_FP16_AVAILABLE + const float2 tmp = __half22float2(v*u); + acc += tmp.x + tmp.y; +#else + const float2 tmpv = __half22float2(v); + const float2 tmpu = __half22float2(u); + acc += tmpv.x * tmpu.x; + acc += tmpv.y * tmpu.y; +#endif // FAST_FP16_AVAILABLE +#endif // V_DOT2_F32_F16_AVAILABLE +} + +static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) { +#ifdef FAST_FP16_AVAILABLE + acc += v*u; +#else + const float2 tmpv = __half22float2(v); + const float2 tmpu = __half22float2(u); + float2 tmpacc = __half22float2(acc); + tmpacc.x += tmpv.x * tmpu.x; + tmpacc.y += tmpv.y * tmpu.y; + acc = make_half2(tmpacc.x, tmpacc.y); +#endif // FAST_FP16_AVAILABLE +} + +// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD. +// Important: do not use this function if dst and src both point at registers. +// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types. +// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions. +// If dst and src point at different address spaces then they are guaranteed to not be aliased. +template <int nbytes, int alignment = 0> +static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) { + static_assert( + nbytes <= ggml_cuda_get_max_cpy_bytes() || alignment == 0, + "You are misusing the alignment parameter for ggml_cuda_memcpy_1. " + "The intent is for the parameter is only as a workaround if either one of the pointers is not properly aligned. " + "If you use it to do more bytes per copy than ggml_cuda_max_cpy_bytes() the reads and writes may not be coalesced. " + "Call ggml_cuda_memcpy_1 in a loop instead."); + if constexpr (alignment != 0) { + static_assert(nbytes % alignment == 0, "bad alignment"); + } + constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment; + +#pragma unroll + for (int i = 0; i < nbytes/nb_per_cpy; ++i) { + if constexpr (nb_per_cpy == 1) { + ((char *) dst)[i] = ((const char *) src)[i]; + } else if constexpr (nb_per_cpy == 2) { + ((short *) dst)[i] = ((const short *) src)[i]; + } else if constexpr (nb_per_cpy == 4) { + ((int *) dst)[i] = ((const int *) src)[i]; + } else if constexpr (nb_per_cpy == 8) { + ((int2 *) dst)[i] = ((const int2 *) src)[i]; + } else if constexpr (nb_per_cpy == 16) { + ((int4 *) dst)[i] = ((const int4 *) src)[i]; + } else { + static_assert(nbytes == 0 && nbytes == -1, "bad nbytes"); + } + } +} + +static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { +#if CUDART_VERSION >= 12080 + const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x); + return (float) e; +#else + uint32_t bits; + if (x == 0) { + bits = 0x00400000; + } else { + bits = (uint32_t) x << 23; + } + + float result; + memcpy(&result, &bits, sizeof(float)); + return result; +#endif // CUDART_VERSION >= 12050 +} + +__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) { + const uint8_t sign_bit = (x < 0.0f) << 3; + float ax = fabsf(x) * e; + + // Positive LUT + static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f }; + + int best_i = 0; + float best_err = fabsf(ax - pos_lut[0]); + +#pragma unroll + for (int i = 1; i < 8; ++i) { + const float err = fabsf(ax - pos_lut[i]); + if (err < best_err) { + best_err = err; + best_i = i; + } + } + + return static_cast<uint8_t>(best_i | sign_bit); +} + +// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. +// Precompute mp (m' in the paper) and L such that division +// can be computed using a multiply (high 32b of 64b result) +// and a shift: +// +// n/d = (mulhi(n, mp) + n) >> L; +static const uint3 init_fastdiv_values(uint64_t d_64) { + GGML_ASSERT(d_64 != 0); + GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max()); + + uint32_t d = (uint32_t)d_64; + + // compute L = ceil(log2(d)); + uint32_t L = 0; + while (L < 32 && (uint32_t{ 1 } << L) < d) { + L++; + } + + uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1); + // pack divisor as well to reduce error surface + return make_uint3(mp, L, d); +} + +static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fastdiv_values) { + // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> + // fastdiv_values.z is unused and optimized away by the compiler. + // Compute high 32 bits of n * mp + const uint32_t hi = __umulhi(n, fastdiv_values.x); + // add n, apply bit shift + return (hi + n) >> fastdiv_values.y; +} + +static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fastdiv_values) { + // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values) + return n - fastdiv(n, fastdiv_values) * fastdiv_values.z; +} + +// Calculate both division and modulo at once, returns <n/divisor, n%divisor> +static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 fastdiv_values) { + // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values) + const uint32_t div_val = fastdiv(n, fastdiv_values); + const uint32_t mod_val = n - div_val * fastdiv_values.z; + return make_uint2(div_val, mod_val); +} + +typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v); + +static __device__ __forceinline__ float get_alibi_slope( + const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1 +) { + if (max_bias <= 0.0f) { + return 1.0f; + } + const float base = h < n_head_log2 ? m0 : m1; + const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + + return powf(base, exph); +} + +template <ggml_type type> +struct ggml_cuda_type_traits; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_F16> { + static constexpr int qk = 1; + static constexpr int qr = 1; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> { + static constexpr int qk = QK4_0; + static constexpr int qr = QR4_0; + static constexpr int qi = QI4_0; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> { + static constexpr int qk = QK4_1; + static constexpr int qr = QR4_1; + static constexpr int qi = QI4_1; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> { + static constexpr int qk = QK5_0; + static constexpr int qr = QR5_0; + static constexpr int qi = QI5_0; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> { + static constexpr int qk = QK5_1; + static constexpr int qr = QR5_1; + static constexpr int qi = QI5_1; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> { + static constexpr int qk = QK8_0; + static constexpr int qr = QR8_0; + static constexpr int qi = QI8_0; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> { + static constexpr int qk = QK_MXFP4; + static constexpr int qr = QR_MXFP4; + static constexpr int qi = QI_MXFP4; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_K; + static constexpr int qi = QI2_K; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> { + static constexpr int qk = QK_K; + static constexpr int qr = QR3_K; + static constexpr int qi = QI3_K; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_K; + static constexpr int qi = QI4_K; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> { + static constexpr int qk = QK_K; + static constexpr int qr = QR5_K; + static constexpr int qi = QI5_K; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> { + static constexpr int qk = QK_K; + static constexpr int qr = QR6_K; + static constexpr int qi = QI6_K; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_XXS; + static constexpr int qi = QI2_XXS; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_XS; + static constexpr int qi = QI2_XS; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> { + static constexpr int qk = QK_K; + static constexpr int qr = QR2_S; + static constexpr int qi = QI2_S; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> { + static constexpr int qk = QK_K; + static constexpr int qr = QR3_XXS; + static constexpr int qi = QI3_XXS; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> { + static constexpr int qk = QK_K; + static constexpr int qr = QR1_S; + static constexpr int qi = QI1_S; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> { + static constexpr int qk = QK_K; + static constexpr int qr = QR1_M; + static constexpr int qi = QI1_M; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> { + static constexpr int qk = QK4_NL; + static constexpr int qr = QR4_NL; + static constexpr int qi = QI4_NL; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> { + static constexpr int qk = QK_K; + static constexpr int qr = QR4_XS; + static constexpr int qi = QI4_XS; +}; + +template<> +struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> { + static constexpr int qk = QK_K; + static constexpr int qr = QR3_S; + static constexpr int qi = QI3_S; +}; + +////////////////////// + +struct ggml_cuda_device_info { + int device_count; + + struct cuda_device_info { + int cc; // compute capability + int nsm; // number of streaming multiprocessors + size_t smpb; // max. shared memory per block + size_t smpbo; // max. shared memory per block (with opt-in) + bool integrated; // Device is integrated as opposed to discrete + bool vmm; // virtual memory support + size_t vmm_granularity; // granularity of virtual memory + size_t total_vram; + int warp_size; // Number of threads in a dispatch + bool supports_cooperative_launch; // whether cooperative launch is supported + }; + + cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {}; + + std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {}; +}; + +const ggml_cuda_device_info & ggml_cuda_info(); + +void ggml_cuda_set_device(int device); +int ggml_cuda_get_device(); + +struct ggml_cuda_pool { + virtual ~ggml_cuda_pool() = default; + + virtual void * alloc(size_t size, size_t * actual_size) = 0; + virtual void free(void * ptr, size_t size) = 0; +}; + +template<typename T> +struct ggml_cuda_pool_alloc { + ggml_cuda_pool * pool = nullptr; + T * ptr = nullptr; + size_t actual_size = 0; + + ggml_cuda_pool_alloc() = default; + + explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) { + } + + ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : pool(&pool) { + alloc(size); + } + + ~ggml_cuda_pool_alloc() { + if (ptr != nullptr) { + pool->free(ptr, actual_size); + } + } + + // size is in number of elements + T * alloc(size_t size) { + GGML_ASSERT(pool != nullptr); + GGML_ASSERT(ptr == nullptr); + ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size); + return ptr; + } + + T * alloc(ggml_cuda_pool & pool, size_t size) { + this->pool = &pool; + return alloc(size); + } + + T * get() { + return ptr; + } + + ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete; + ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete; + ggml_cuda_pool_alloc& operator=(const ggml_cuda_pool_alloc &) = delete; + ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete; +}; + + +// backend interface + +struct ggml_tensor_extra_gpu { + void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors + cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs +}; + + +#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)) || defined(GGML_MUSA_GRAPHS) +#define USE_CUDA_GRAPH +#endif + +struct ggml_cuda_graph_node_properties { + void * node_data; + ggml_op node_op; + enum ggml_type node_type; + int32_t flags; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + void * src_data[GGML_MAX_SRC]; + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; +}; + +static_assert(std::is_trivial<ggml_cuda_graph_node_properties>::value, "ggml_cuda_graph_node_properties must be trivial"); + +struct ggml_cuda_graph { +#ifdef USE_CUDA_GRAPH + ~ggml_cuda_graph() { + if (instance != nullptr) { + CUDA_CHECK(cudaGraphExecDestroy(instance)); + } + if (graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(graph)); + } + } + cudaGraph_t graph = nullptr; + cudaGraphExec_t instance = nullptr; + size_t num_nodes = 0; + std::vector<cudaGraphNode_t> nodes; + bool disable_due_to_gpu_arch = false; + bool disable_due_to_too_many_updates = false; + int number_consecutive_updates = 0; + std::vector<ggml_cuda_graph_node_properties> props; + + // these are extra tensors (inputs) that participate in the ggml graph but are not nodes + // they properties also have to match in order to be able to safely reuse a CUDA graph + // ref: https://github.com/ggml-org/llama.cpp/pull/18583 + // ref: https://github.com/ggml-org/llama.cpp/pull/19165 + std::vector<ggml_cuda_graph_node_properties> extra; + + void record_update(bool use_graph, bool update_required) { + if (use_graph && update_required) { + number_consecutive_updates++; + } else { + number_consecutive_updates = 0; + } + if (number_consecutive_updates >= 4) { + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); + disable_due_to_too_many_updates = true; + } + } + + bool is_enabled() const { + static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); + return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates); + } +#endif +}; + +struct ggml_cuda_concurrent_event { + std::vector<cudaEvent_t> join_events; + cudaEvent_t fork_event = nullptr; + + int n_streams = 0; + std::unordered_map<const ggml_tensor *, int> stream_mapping; + + // Original order of nodes in this concurrent region (before interleaving) + // Used to restore grouping for fusion within streams + std::vector<const ggml_tensor *> original_order; + + const ggml_tensor * join_node; + + ggml_cuda_concurrent_event() = default; + + ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete; + ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete; + + explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) { + join_events.resize(n_streams); + + for (size_t i = 0; i < join_events.size(); ++i) { + CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming)); + } + + CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming)); + } + + ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept + : join_events(std::move(other.join_events)) + , fork_event(other.fork_event) + , n_streams(other.n_streams) + , stream_mapping(std::move(other.stream_mapping)) + , original_order(std::move(other.original_order)) + , join_node(other.join_node) { + other.fork_event = nullptr; + } + + // 1. check if any branches write to overlapping memory ranges (except the join node) + // 2. check whether all srcs are either within the branch or outside the nodes covered by ggml_cuda_concurrent_event + // we assume all nodes have the same buffer + bool is_valid() const { + std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges; + write_ranges.resize(n_streams); + + // get join_node's memory range to exclude from overlap checking. + // multiple nodes can use join_node's buffer; we synchronize on the join node. + const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node; + const int64_t join_start = (int64_t) join_t->data; + const int64_t join_end = join_start + ggml_nbytes(join_t); + + for (const auto & [tensor, stream] : stream_mapping) { + const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor; + const int64_t t_start = (int64_t) t->data; + const int64_t t_end = t_start + ggml_nbytes(t); + + // skip tensors that overlap with join_node's buffer. + if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) { + continue; + } + + // concurrent streams begin from 1 + write_ranges[stream - 1].emplace_back(t_start, t_end); + } + + for (int i = 0; i < n_streams; ++i) { + // sorts first by start then by end of write range + std::sort(write_ranges[i].begin(), write_ranges[i].end()); + } + + bool writes_overlap = false; + bool dependent_srcs = false; + for (const auto & [tensor, stream] : stream_mapping) { + const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor; + const int64_t t_start = (int64_t) t->data; + const int64_t t_end = t_start + ggml_nbytes(t); + + // skip tensors that overlap with join_node's buffer + if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) { + continue; + } + + // check if this buffer's write data overlaps with another stream's + std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end); + for (int i = 0; i < n_streams; ++i) { + if (i == stream - 1) { + continue; + } + auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range); + + if (it != write_ranges[i].end()) { + const std::pair<int64_t, int64_t> & other = *it; + + // std::lower_bound returns the first element where other >= data_range (lexicographically). + // This guarantees other.first >= data_range.first. + // Therefore, overlap occurs iff other.first < data_range.second + // (i.e., the other range starts before this range ends). + if (other.first < data_range.second) { + GGML_LOG_DEBUG("Writes overlap for %s", tensor->name); + writes_overlap = true; + break; + } + } + } + + //check if all srcs are either in branch or don't have a branch + for (int i = 0; i < GGML_MAX_SRC; ++i) { + if (!tensor->src[i]) { + continue; + } + + auto it = stream_mapping.find(tensor->src[i]); + + if (it == stream_mapping.end()) { + continue; + } + + if (it->second != stream) { + dependent_srcs = true; + break; + } + } + + if (dependent_srcs || writes_overlap) { + break; + } + } + + return !writes_overlap && !dependent_srcs; + } + + ~ggml_cuda_concurrent_event() { + if (fork_event != nullptr) { + CUDA_CHECK(cudaEventDestroy(fork_event)); + } + for (cudaEvent_t e : join_events) { + if (e != nullptr) { + CUDA_CHECK(cudaEventDestroy(e)); + } + } + } +}; + +struct ggml_cuda_stream_context { + std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events; + + void reset() { + concurrent_events.clear(); + } +}; + +struct ggml_backend_cuda_context { + int device; + std::string name; + cudaEvent_t copy_event = nullptr; + + cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } }; + cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; + + int curr_stream_no = 0; + +#ifdef USE_CUDA_GRAPH + // Map from first_node_ptr to cuda_graph - allows multiple graphs per context + // when the computation is split across CPU/GPU (e.g., with --n-cpu-moe) + std::unordered_map<const void *, std::unique_ptr<ggml_cuda_graph>> cuda_graphs; + + ggml_cuda_graph * cuda_graph(const void * first_node_ptr) { + auto it = cuda_graphs.find(first_node_ptr); + if (it == cuda_graphs.end()) { + cuda_graphs[first_node_ptr] = std::make_unique<ggml_cuda_graph>(); + return cuda_graphs[first_node_ptr].get(); + } + return it->second.get(); + } + + // Check if any CUDA graph is enabled for this context (used by kernels that need to know + // if graphs are in use without having access to the specific graph key) + bool any_cuda_graph_enabled() const { + for (const auto & [key, graph] : cuda_graphs) { + if (graph && graph->is_enabled()) { + return true; + } + } + return false; + } + + // Check if any CUDA graph has an instance for this context + bool any_cuda_graph_has_instance() const { + for (const auto & [key, graph] : cuda_graphs) { + if (graph && graph->instance != nullptr) { + return true; + } + } + return false; + } +#endif // USE_CUDA_GRAPH + + explicit ggml_backend_cuda_context(int device) : + device(device), + name(GGML_CUDA_NAME + std::to_string(device)) { + } + + ggml_cuda_stream_context concurrent_stream_context; + + ~ggml_backend_cuda_context(); + + cudaStream_t stream(int device, int stream) { + if (streams[device][stream] == nullptr) { + ggml_cuda_set_device(device); + CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking)); + } + return streams[device][stream]; + } + + cudaStream_t stream() { return stream(device, curr_stream_no); } + + ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; } + + cublasHandle_t cublas_handle(int device) { + if (cublas_handles[device] == nullptr) { + ggml_cuda_set_device(device); + CUBLAS_CHECK(cublasCreate(&cublas_handles[device])); + CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH)); + } + return cublas_handles[device]; + } + + cublasHandle_t cublas_handle() { + return cublas_handle(device); + } + + // pool + std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; + + static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no); + + ggml_cuda_pool & pool(int device) { + if (pools[device][curr_stream_no] == nullptr) { + pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no); + } + return *pools[device][curr_stream_no]; + } + + ggml_cuda_pool & pool() { + return pool(device); + } +}; + +struct ggml_cuda_mm_fusion_args_host { + const ggml_tensor * x_bias = nullptr; + const ggml_tensor * gate = nullptr; + const ggml_tensor * gate_bias = nullptr; + ggml_glu_op glu_op; +}; +struct ggml_cuda_mm_fusion_args_device { + const void * x_bias = nullptr; + const void * gate = nullptr; + const void * gate_bias = nullptr; + ggml_glu_op glu_op; +}; diff --git a/llama.cpp/ggml/src/ggml-cuda/concat.cu b/llama.cpp/ggml/src/ggml-cuda/concat.cu new file mode 100644 index 0000000..e9ffd27 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/concat.cu @@ -0,0 +1,221 @@ +#include "concat.cuh" + +// contiguous kernels +static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) { + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + + int offset_dst = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + + if (nidx < ne00) { // src0 + int offset_src = + nidx + + blockIdx.y * ne00 + + blockIdx.z * ne00 * gridDim.y; + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + (nidx - ne00) + + blockIdx.y * (ne0 - ne00) + + blockIdx.z * (ne0 - ne00) * gridDim.y; + dst[offset_dst] = y[offset_src]; + } +} + +static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) { + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + + int offset_dst = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + + if (blockIdx.y < (unsigned)ne01) { // src0 + int offset_src = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * ne01; + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + nidx + + (blockIdx.y - ne01) * ne0 + + blockIdx.z * ne0 * (gridDim.y - ne01); + dst[offset_dst] = y[offset_src]; + } +} + +static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) { + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + if (nidx >= ne0) { + return; + } + + int offset_dst = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + + if (blockIdx.z < (unsigned)ne02) { // src0 + int offset_src = + nidx + + blockIdx.y * ne0 + + blockIdx.z * ne0 * gridDim.y; + dst[offset_dst] = x[offset_src]; + } else { + int offset_src = + nidx + + blockIdx.y * ne0 + + (blockIdx.z - ne02) * ne0 * gridDim.y; + dst[offset_dst] = y[offset_src]; + } +} + +static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) { + int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; + dim3 gridDim(num_blocks, ne1, ne2); + if (dim == 0) { + concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00); + return; + } + if (dim == 1) { + concat_f32_dim1<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne01); + return; + } + concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02); +} + +// non-contiguous kernel (slow) +template <int dim> +static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) + concat_f32_non_cont( + const char * src0, + const char * src1, + char * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne03, + uint64_t nb00, + uint64_t nb01, + uint64_t nb02, + uint64_t nb03, + int64_t /*ne10*/, + int64_t /*ne11*/, + int64_t /*ne12*/, + int64_t /*ne13*/, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + uint64_t nb13, + int64_t ne0, + int64_t /*ne1*/, + int64_t /*ne2*/, + int64_t /*ne3*/, + uint64_t nb0, + uint64_t nb1, + uint64_t nb2, + uint64_t nb3){ + static_assert(dim >= 0 && dim <= 3, "dim must be in [0, 3]"); + + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + const int64_t i1 = blockIdx.x; + + const float * x; + + for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + } else { + if constexpr (dim == 0) { + x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10); + } else if constexpr (dim == 1) { + x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10); + } else if constexpr (dim == 2) { + x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10); + } else if constexpr (dim == 3) { + x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10); + } + } + + float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = *x; + } +} + + +void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + cudaStream_t stream = ctx.stream(); + + const int32_t dim = ((int32_t *) dst->op_params)[0]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + + float * dst_d = (float *)dst->data; + + if (dim != 3) { + for (int i3 = 0; i3 < dst->ne[3]; i3++) { + concat_f32_cuda( + src0_d + i3 * (src0->nb[3] / 4), + src1_d + i3 * (src1->nb[3] / 4), + dst_d + i3 * ( dst->nb[3] / 4), + src0->ne[0], src0->ne[1], src0->ne[2], + dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); + } + } else { + const size_t size0 = ggml_nbytes(src0); + const size_t size1 = ggml_nbytes(src1); + + CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); + } + } else { + dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); + auto launch_kernel = [&](auto dim) { + concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>( + (const char *) src0->data, (const char *) src1->data, (char *) dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]); + }; + switch (dim) { + case 0: + launch_kernel(std::integral_constant<int, 0>{}); + break; + case 1: + launch_kernel(std::integral_constant<int, 1>{}); + break; + case 2: + launch_kernel(std::integral_constant<int, 2>{}); + break; + case 3: + launch_kernel(std::integral_constant<int, 3>{}); + break; + default: + GGML_ABORT("Invalid dim: %d", dim); + break; + } + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/concat.cuh b/llama.cpp/ggml/src/ggml-cuda/concat.cuh new file mode 100644 index 0000000..aa506a0 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/concat.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_CONCAT_BLOCK_SIZE 256 + +void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu b/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu new file mode 100644 index 0000000..8418ba6 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu @@ -0,0 +1,86 @@ +#include "conv-transpose-1d.cuh" + +static __global__ void conv_transpose_1d_kernel( + const int s0, const int p0, const int d0, const int output_size, + const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, + const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, + const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, + const float * src0, const float * src1, float * dst) { + int global_index = threadIdx.x + blockIdx.x * blockDim.x; + if (global_index >= output_size) { + return; + } + + int out_index = global_index / dst_ne0; + + float accumulator = 0; + + for (int c = 0; c < src0_ne2; c++) { + int idx = global_index % dst_ne0; + + int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0); + int input_offset = src1_ne0 * c; + + for (int i = 0; i < src1_ne0; i++) { + if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) { + continue; + } + int weight_idx = idx - i*s0; + + float kernel_weight = src0[kernel_offset + weight_idx]; + float input_value = src1[input_offset+i]; + + accumulator += kernel_weight * input_value; + } + } + dst[global_index] = accumulator; + GGML_UNUSED_VARS(p0, d0, src0_ne3, src1_ne3, dst_ne3, src1_ne1, dst_ne1, src1_ne2, dst_ne2); +} + +static void conv_transpose_1d_f32_f32_cuda( + const int s0, const int p0, const int d0, const int output_size, + const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3, + const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3, + const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3, + const float * src0, const float * src1, float * dst, + cudaStream_t stream) { + + const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE; + conv_transpose_1d_kernel<<<num_blocks,CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE, 0, stream>>>( + s0,p0,d0,output_size, + src0_ne0, src0_ne1, src0_ne2, src0_ne3, + src1_ne0, src1_ne1, src1_ne2, src1_ne3, + dst_ne0, dst_ne1, dst_ne2, dst_ne3, + src0,src1, dst); +} + +void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + + const ggml_tensor * src1 = dst->src[1]; + const float * src1_d = (const float *)src1->data; + + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const int32_t * opts = (const int32_t *)dst->op_params; + + const int s0 = opts[0]; + const int p0 = 0;//opts[3]; + const int d0 = 1;//opts[4]; + + const int64_t output_size = ggml_nelements(dst); + + conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + src0_d, src1_d, dst_d, stream); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cuh b/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cuh new file mode 100644 index 0000000..6c2cf66 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256 + +void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu b/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu new file mode 100644 index 0000000..7583233 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu @@ -0,0 +1,161 @@ +#include "conv2d-dw.cuh" + +struct conv_params { + int in_w, in_h; + int out_w, out_h; + int kernel_w, kernel_h; + int stride_x, stride_y; + int padding_x, padding_y; + int dilation_x, dilation_y; + int channels, batches; +}; + +struct kernel_bounds { + int y_min, y_max; + int x_min, x_max; +}; + +__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) { + kernel_bounds bounds; + bounds.y_min = max(0, (params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y); + bounds.y_max = + min(params.kernel_h, + (params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y); + bounds.x_min = max(0, (params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x); + bounds.x_max = + min(params.kernel_w, + (params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x); + return bounds; +} + +__device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) { + return out_coord * stride + kern_coord * dilation - padding; +} + +struct whcn_layout { + __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) { + return n * (params.channels * params.in_w * params.in_h) + c * params.in_w * params.in_h + y * params.in_w + x; + } + + __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) { + return c * params.kernel_h * params.kernel_w + ky * params.kernel_w + kx; + } + + __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) { + return n * (params.channels * params.out_w * params.out_h) + c * params.out_w * params.out_h + + y * params.out_w + x; + } + + __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y, + int & out_x) { + out_x = global_idx % params.out_w; + out_y = (global_idx / params.out_w) % params.out_h; + c = (global_idx / (params.out_w * params.out_h)) % params.channels; + n = global_idx / (params.out_w * params.out_h * params.channels); + } +}; + +struct cwhn_layout { + __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) { + return n * (params.channels * params.in_w * params.in_h) + (y * params.in_w + x) * params.channels + c; + } + + __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) { + return (ky * params.kernel_w + kx) * params.channels + c; + } + + __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) { + return n * (params.channels * params.out_w * params.out_h) + y * (params.out_w * params.channels) + + x * params.channels + c; + } + + __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y, + int & out_x) { + c = global_idx % params.channels; + out_x = (global_idx / params.channels) % params.out_w; + out_y = (global_idx / (params.channels * params.out_w)) % params.out_h; + n = global_idx / (params.channels * params.out_w * params.out_h); + } +}; + +template <typename T, typename Layout> +__global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output, + const int in_w, const int in_h, const int out_w, const int out_h, + const int kernel_w, const int kernel_h, const int stride_x, const int stride_y, + const int padding_x, const int padding_y, const int dilation_x, const int dilation_y, + const int channels, const int batches) { + const int global_idx = blockIdx.x * blockDim.x + threadIdx.x; + const int total_elements = batches * channels * out_h * out_w; + + if (global_idx >= total_elements) { + return; + } + + conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, + stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches }; + + int batch_idx, channel_idx, out_y_idx, out_x_idx; + Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx); + + T accumulator = 0; + kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params); + + for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) { + int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y); + + for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) { + int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x); + + const T input_val = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)]; + const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)]; + + accumulator += input_val * kernel_val; + } + } + + output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator; +} + +void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * kernel = dst->src[0]; + const ggml_tensor * input = dst->src[1]; + + GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); + const float * w_d = (const float *) kernel->data; + const float * x_d = (const float *) input->data; + float * y_d = (float *) dst->data; + + const int32_t * p = (const int32_t *) dst->op_params; + const int stride_x = p[0]; + const int stride_y = p[1]; + const int padding_x = p[2]; + const int padding_y = p[3]; + const int dilation_x = p[4]; + const int dilation_y = p[5]; + + const int in_w = input->ne[0]; + const int in_h = input->ne[1]; + const int kernel_w = kernel->ne[0]; + const int kernel_h = kernel->ne[1]; + const int out_w = dst->ne[0]; + const int out_h = dst->ne[1]; + const int channels = dst->ne[2]; + const int batches = dst->ne[3]; + + cudaStream_t st = ctx.stream(); + + const int total = batches * channels * out_h * out_w; + const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE; + + if (ggml_is_contiguous(input)) { + conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>( + x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y, + dilation_x, dilation_y, channels, batches); + } else if (ggml_is_contiguous_channels(input)) { + conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>( + x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y, + dilation_x, dilation_y, channels, batches); + } else { + GGML_ABORT("Unsupported memory layout for conv_2d_dw"); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh b/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh new file mode 100644 index 0000000..b5d5a69 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh @@ -0,0 +1,5 @@ +#pragma once +#include "common.cuh" + +#define CUDA_CONV2D_DW_BLOCK_SIZE 256 +void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu b/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu new file mode 100644 index 0000000..03224e4 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu @@ -0,0 +1,91 @@ +#include <algorithm> + +#include "conv2d-transpose.cuh" +#include "ggml.h" + +__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel, + float * __restrict__ output, const int in_w, const int in_h, const int out_w, + const int out_h, const int kernel_w, const int kernel_h, const int stride, + const int c_in, const int c_out, const int batches) { + const int global_idx = blockIdx.x * blockDim.x + threadIdx.x; + + const int total_elements = out_w * out_h * c_out * batches; + + if (global_idx >= total_elements) { + return; + } + + const int out_x_idx = global_idx % out_w; + const int out_y_idx = (global_idx / out_w) % out_h; + const int c_idx = (global_idx / (out_w * out_h)) % c_out; + const int n_idx = global_idx / (out_w * out_h * c_out); + + float accumulator = 0; + // For each output idx, find the inputs that contribute to it by checking stride alignment and bounds + + for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) { + for (int kh = 0; kh < kernel_h; ++kh) { + int in_y = out_y_idx - kh; + if (in_y < 0 || in_y % stride) continue; + in_y /= stride; + if (in_y >= in_h) continue; + + for (int kw = 0; kw < kernel_w; ++kw) { + int in_x = out_x_idx - kw; + if (in_x < 0 || in_x % stride) continue; + in_x /= stride; + if (in_x >= in_w) continue; + + const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x; + const int kernel_idx = + (kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw; + + float input_val = input[input_idx]; + half kern_val = kernel[kernel_idx]; + + accumulator += input_val * (float) kern_val; + } + } + } + + output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + (out_w) *out_y_idx + out_x_idx] = accumulator; +} + +//input is (W, H, C_in, N), Kernel is (W, H, C_out, C_in) +void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * kernel = dst->src[0]; + const ggml_tensor * input = dst->src[1]; + + GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); + + const float * input_data = (const float *) input->data; + float * output_data = (float *) dst->data; + const half * kernel_data = (const half *) kernel->data; + + const int input_w = input->ne[0]; + const int input_h = input->ne[1]; + const int output_w = dst->ne[0]; + const int output_h = dst->ne[1]; + const int channels_in = input->ne[2]; + const int channels_out = kernel->ne[2]; + const int kernel_w = kernel->ne[0]; + const int kernel_h = kernel->ne[1]; + const int stride = dst->op_params[0]; + const int batches = input->ne[3]; + + GGML_ASSERT(channels_in == kernel->ne[3]); + GGML_ASSERT(stride > 0); + + cudaStream_t st = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(input)); + GGML_ASSERT(ggml_is_contiguous(kernel)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + const int total = (output_w * output_h * channels_out * batches); + const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE; + + conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>( + input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride, + channels_in, channels_out, batches); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh b/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh new file mode 100644 index 0000000..c9430b2 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh @@ -0,0 +1,4 @@ +#include "common.cuh" + +#define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256 +void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/conv2d.cu b/llama.cpp/ggml/src/ggml-cuda/conv2d.cu new file mode 100644 index 0000000..142dd66 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/conv2d.cu @@ -0,0 +1,166 @@ +#include "conv2d.cuh" +#include "convert.cuh" + +struct conv_params { + const int64_t IW, IH; + const int64_t OW, OH; + const int64_t KW, KH; + const int64_t ST_X, ST_Y; + const int64_t PD_X, PD_Y; + const int64_t DL_X, DL_Y; + const int64_t IC, OC; + const int64_t B; + const int64_t TOTAL; +}; + +struct kernel_bounds { + int64_t y_min, y_max; + int64_t x_min, x_max; +}; + +__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) { + return (a > b) ? a : b; +} + +__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) { + return (a < b) ? a : b; +} + +__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) { + kernel_bounds bounds; + bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); + bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y); + bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); + bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X); + return bounds; +} + +__device__ __forceinline__ int calculate_input_coord(int64_t out_coord, + int64_t kern_coord, + int64_t stride, + int64_t dilation, + int64_t padding) { + return out_coord * stride + kern_coord * dilation - padding; +} + +struct whcn_layout { + __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { + return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x; + } + + __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) { + return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx; + } + + __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) { + return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x; + } + + __device__ static void unpack_indices(int64_t global_idx, + const conv_params & P, + int64_t & n, + int64_t & c, + int64_t & out_y, + int64_t & out_x) { + out_x = global_idx % P.OW; + out_y = (global_idx / P.OW) % P.OH; + c = (global_idx / (P.OW * P.OH)) % P.OC; + n = global_idx / (P.OW * P.OH * P.OC); + } +}; + +template <typename T, typename Layout> +static __global__ void conv2d_kernel(const float * __restrict__ input, + const T * __restrict__ kernel, + float * __restrict__ output, + const conv_params P) { + const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (global_idx >= P.TOTAL) { + return; + } + + int64_t n, c_out, out_y, out_x; + Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x); + + float acc = 0.0f; + + for (int64_t c_in = 0; c_in < P.IC; ++c_in) { + kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P); + + for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) { + const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y); + + for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) { + const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X); + + const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)]; + const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)]; + acc += (input_val * ggml_cuda_cast<float>(kernel_val)); + } + } + } + + // [N, OC, OH, OW] + output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc; +} + +template <typename T> +static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE; + conv2d_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P); +} + +static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + conv2d_cuda<half>(X_D, K_D, Y_D, P, st); +} + +static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) { + conv2d_cuda<float>(X_D, K_D, Y_D, P, st); +} + +void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * kernel = dst->src[0]; + const ggml_tensor * input = dst->src[1]; + float * K_D = (float *) kernel->data; + const float * X_D = (const float *) input->data; + float * Y_D = (float *) dst->data; + + GGML_ASSERT(ggml_is_contiguous(kernel)); + GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32); + + // same number of input channels + GGML_ASSERT(input->ne[2] == kernel->ne[2]); + + cudaStream_t st = ctx.stream(); + + const int32_t * p = (const int32_t *) dst->op_params; + const int ST_X = p[0]; // stride_x + const int ST_Y = p[1]; // stride_y + const int PD_X = p[2]; // padding_x + const int PD_Y = p[3]; // padding_y + const int DL_X = p[4]; // dilation_x + const int DL_Y = p[5]; // dilation_y + + // No cwhn + GGML_ASSERT(p[6] == false); + + const int IW = input->ne[0]; // input_w + const int IH = input->ne[1]; // input_h + const int OW = dst->ne[0]; // output_w + const int OH = dst->ne[1]; // output_h + const int KW = kernel->ne[0]; // kernel_w + const int KH = kernel->ne[1]; // kernel_h + const int IC = input->ne[2]; // input_channels + const int OC = kernel->ne[3]; // ouptut_chanles + const int B = input->ne[3]; // n_batches + + const int64_t total = B * OC * OH * OW; + conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total }; + + if (kernel->type == GGML_TYPE_F16) { + conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st); + } else { + conv2d_cuda_f32(X_D, K_D, Y_D, params, st); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh b/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh new file mode 100644 index 0000000..ce4802c --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh @@ -0,0 +1,5 @@ +#pragma once +#include "common.cuh" + +#define CUDA_CONV2D_BLOCK_SIZE 256 +void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/convert.cu b/llama.cpp/ggml/src/ggml-cuda/convert.cu new file mode 100644 index 0000000..ba3d4ee --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/convert.cu @@ -0,0 +1,825 @@ +#include "convert.cuh" +#include "dequantize.cuh" + +#include <cstdint> + +#define CUDA_Q8_0_NE_ALIGN 2048 + +template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> +static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t s01, const int64_t s02, const int64_t s03) { + const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x); + + if (i00 >= ne00) { + return; + } + + const int64_t i01 = blockIdx.y; + const int64_t i02 = blockIdx.z % ne02; + const int64_t i03 = blockIdx.z / ne02; + + const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; + + const int64_t ib = ibx0 + i00/qk; // block index + const int64_t iqs = (i00%qk)/qr; // quant index + const int64_t iybs = i00 - i00%qk; // y block start index + const int64_t y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + float2 v; + dequantize_kernel(vx, ib, iqs, v); + + const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs; + y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x); + y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y); +} + +template <bool need_check> +static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) { +#if __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL + constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE; + + const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x; + const int * x0 = ((int *) vx) + blockIdx.x * nint; + half2 * y2 = (half2 *) (y + i0); + + __shared__ int vals[nint]; + +#pragma unroll + for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) { + if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) { + break; + } + + const int ix = ix0 + threadIdx.x; + vals[ix] = x0[ix]; + } + + __syncthreads(); + +#pragma unroll + for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) { + if (need_check && i0 + iy + 2*threadIdx.x >= k) { + return; + } + + const half * b0 = ((const half *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0); + const half d = *b0; + const char2 qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)]; + + y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d)); + } +#else + GGML_UNUSED_VARS(vx, y, k); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL +} + +template<typename dst_t> +static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { + + const int64_t i = blockIdx.x; + + // assume 32 threads + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; + const int64_t ir = tid%8; + const int64_t ib = 8*i + ir; + if (ib >= nb32) { + return; + } + + dst_t * y = yy + 256*i + 32*ir + 4*il; + + const block_q4_0 * x = (const block_q4_0 *)vx + ib; + const float d = __half2float(x->d); + const float dm = -8*d; + + const uint8_t * q = x->qs + 4*il; + + for (int l = 0; l < 4; ++l) { + y[l+ 0] = d * (q[l] & 0xF) + dm; + y[l+16] = d * (q[l] >> 4) + dm; + } +} + +template<typename dst_t> +static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { + + const int64_t i = blockIdx.x; + + // assume 32 threads + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; + const int64_t ir = tid%8; + const int64_t ib = 8*i + ir; + if (ib >= nb32) { + return; + } + + dst_t * y = yy + 256*i + 32*ir + 4*il; + + const block_q4_1 * x = (const block_q4_1 *)vx + ib; + const float2 d = __half22float2(x->dm); + + const uint8_t * q = x->qs + 4*il; + + for (int l = 0; l < 4; ++l) { + y[l+ 0] = d.x * (q[l] & 0xF) + d.y; + y[l+16] = d.x * (q[l] >> 4) + d.y; + } +} + +//================================== k-quants + +template<typename dst_t> +static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_q2_K * x = (const block_q2_K *) vx; + + const int64_t tid = threadIdx.x; + const int64_t n = tid/32; + const int64_t l = tid - 32*n; + const int64_t is = 8*n + l/16; + + const uint8_t q = x[i].qs[32*n + l]; + dst_t * y = yy + i*QK_K + 128*n; + + float dall = __low2half(x[i].dm); + float dmin = __high2half(x[i].dm); + y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); + y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); + y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); + y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4); +} + +template<typename dst_t> +static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_q3_K * x = (const block_q3_K *) vx; + + const int64_t r = threadIdx.x/4; + const int64_t tid = r/2; + const int64_t is0 = r%2; + const int64_t l0 = 16*is0 + 4*(threadIdx.x%4); + const int64_t n = tid / 4; + const int64_t j = tid - 4*n; + + uint8_t m = 1 << (4*n + j); + int64_t is = 8*n + 2*j + is0; + int shift = 2*j; + + int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) : + (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4); + float d_all = x[i].d; + float dl = d_all * (us - 32); + + dst_t * y = yy + i*QK_K + 128*n + 32*j; + const uint8_t * q = x[i].qs + 32*n; + const uint8_t * hm = x[i].hmask; + + for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); +} + +static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { + if (j < 4) { + d = q[j] & 63; m = q[j + 4] & 63; + } else { + d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} + +template<typename dst_t> +static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q4_K * x = (const block_q4_K *) vx; + + const int64_t i = blockIdx.x; + + // assume 32 threads + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; + const int64_t ir = tid%8; + const int64_t is = 2*il; + const int64_t n = 4; + + dst_t * y = yy + i*QK_K + 64*il + n*ir; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint8_t * q = x[i].qs + 32*il + n*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const float d2 = dall * sc; const float m2 = dmin * m; + for (int l = 0; l < n; ++l) { + y[l + 0] = d1 * (q[l] & 0xF) - m1; + y[l +32] = d2 * (q[l] >> 4) - m2; + } +} + +template<typename dst_t> +static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q5_K * x = (const block_q5_K *) vx; + + const int64_t i = blockIdx.x; + + // assume 64 threads - this is very slightly better than the one below + const int64_t tid = threadIdx.x; + const int64_t il = tid/16; // il is in 0...3 + const int64_t ir = tid%16; // ir is in 0...15 + const int64_t is = 2*il; // is is in 0...6 + + dst_t * y = yy + i*QK_K + 64*il + 2*ir; + + const float dall = __low2half(x[i].dm); + const float dmin = __high2half(x[i].dm); + + const uint8_t * ql = x[i].qs + 32*il + 2*ir; + const uint8_t * qh = x[i].qh + 2*ir; + + uint8_t sc, m; + get_scale_min_k4(is + 0, x[i].scales, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, x[i].scales, sc, m); + const float d2 = dall * sc; const float m2 = dmin * m; + + uint8_t hm = 1 << (2*il); + y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1; + y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1; + hm <<= 1; + y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; + y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; +} + +template<typename dst_t> +static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const block_q6_K * x = (const block_q6_K *) vx; + + const int64_t i = blockIdx.x; + + // assume 64 threads - this is very slightly better than the one below + const int64_t tid = threadIdx.x; + const int64_t ip = tid/32; // ip is 0 or 1 + const int64_t il = tid - 32*ip; // 0...32 + const int64_t is = 8*ip + il/16; + + dst_t * y = yy + i*QK_K + 128*ip + il; + + const float d = x[i].d; + + const uint8_t * ql = x[i].ql + 64*ip + il; + const uint8_t qh = x[i].qh[32*ip + il]; + const int8_t * sc = x[i].scales + is; + + y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32); + y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32); + y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32); + y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); +} + +template<typename dst_t> +static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_iq2_xxs * x = (const block_iq2_xxs *) vx; + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint16_t * q2 = x[i].qs + 4*ib; + const uint8_t * aux8 = (const uint8_t *)q2; + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]); + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f; + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127]; + for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); +} + +template<typename dst_t> +static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_iq2_xs * x = (const block_iq2_xs *) vx; + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint16_t * q2 = x[i].qs + 4*ib; + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511)); + const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; + const uint8_t signs = ksigns_iq2xs[q2[il] >> 9]; + for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); +} + +template<typename dst_t> +static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_iq2_s * x = (const block_iq2_s *) vx; + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300))); + const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; + const uint8_t signs = x[i].qs[QK_K/8+4*ib+il]; + for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); +} + +template<typename dst_t> +static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_iq3_xxs * x = (const block_iq3_xxs *) vx; + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint8_t * q3 = x[i].qs + 8*ib; + const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib; + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]); + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f; + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127]; + for (int j = 0; j < 4; ++j) { + y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + } +} + +template<typename dst_t> +static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_iq3_s * x = (const block_iq3_s *) vx; + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint8_t * qs = x[i].qs + 8*ib; + const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256))); + const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256))); + const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)); + const uint8_t signs = x[i].signs[4*ib + il]; + for (int j = 0; j < 4; ++j) { + y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + } +} + +template<typename dst_t> +static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_iq1_s * x = (const block_iq1_s *) vx; + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA; + const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1); + uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32; + grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)]; + grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f; + grid32[0] &= 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + y[j] = d * (q[j] + delta); + } +} + +template<typename dst_t> +static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_iq1_m * x = (const block_iq1_m *) vx; + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint16_t * sc = (const uint16_t *)x[i].scales; + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4); + const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1); + const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA; + uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32; + grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)]; + grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f; + grid32[0] &= 0x0f0f0f0f; + for (int j = 0; j < 8; ++j) { + y[j] = d * (q[j] + delta); + } +} + +template<typename dst_t> +static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL); + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[ib].qs + 4*il; + const float d = (float)x[ib].d; + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf]; + y[j+16] = d * kvalues_iq4nl[q4[j] >> 4]; + } +} + +template<typename dst_t> +static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) { + const int64_t i = blockIdx.x; + const block_iq4_xs * x = (const block_iq4_xs *)vx; + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[i].qs + 16*ib + 4*il; + const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32); + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf]; + y[j+16] = d * kvalues_iq4nl[q4[j] >> 4]; + } +} + +template<typename dst_t> +static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) { + + const int64_t i = blockIdx.x; + const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4); + + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; // 0...3 + const int64_t ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 4*il; + const uint8_t * q4 = x[ib].qs + 4*il; + const float d = ggml_cuda_e8m0_to_fp32(x[ib].e); + for (int j = 0; j < 4; ++j) { + y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f; + y[j+16] = d * kvalues_mxfp4[q4[j] >> 4]*0.5f; + } +} + +template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> +static void dequantize_block_cuda(const void * vx, dst_t * y, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { + const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03); + dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>> + (vx, y, ne00, ne01, ne02, s01, s02, s03); +} + +template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> +static void dequantize_block_cont_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { + dequantize_block_cuda<qk, qr, dequantize_kernel, dst_t>(vx, y, k, 1, 1, 1, k/qk, k/qk, k/qk, stream); +} + +static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN; + if (k % CUDA_Q8_0_NE_ALIGN == 0) { + const bool need_check = false; + dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k); + } else { + const bool need_check = true; + dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k); + } +} + +template<typename dst_t> +static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y); +} + +template<typename dst_t> +static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y); +} + +template<typename dst_t> +static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb32 = k / 32; + const int nb = (k + 255) / 256; + dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32); +} + +template<typename dst_t> +static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb32 = k / 32; + const int nb = (k + 255) / 256; + dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32); +} + +template<typename dst_t> +static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y); +} + +template<typename dst_t> +static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y); +} + +template<typename dst_t> +static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y); +} + +template<typename dst_t> +static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y); +} + +template<typename dst_t> +static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y); +} + +template<typename dst_t> +static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y); +} + +template<typename dst_t> +static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y); +} + +template<typename dst_t> +static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y); +} + +template<typename dst_t> +static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y); +} + +template<typename dst_t> +static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y); +} + +template<typename dst_t> +static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = k / QK_K; + dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y); +} + +template<typename dst_t> +static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y); +} + +template<typename dst_t> +static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb = (k + QK_K - 1) / QK_K; + dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y); +} + +template <typename src_t, typename dst_t> +static __global__ void convert_unary( + const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t s01, const int64_t s02, const int64_t s03) { + const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; + + if (i00 >= ne00) { + return; + } + + const int64_t i01 = blockIdx.y; + const int64_t i02 = blockIdx.z % ne02; + const int64_t i03 = blockIdx.z / ne02; + + const src_t * x = (const src_t *) vx; + + const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; + const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00; + y[iy] = ggml_cuda_cast<dst_t>(x[ix]); +} + +template <typename src_t, typename dst_t> +static void convert_unary_cuda(const void * vx, dst_t * y, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { + const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03); + convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>> + (vx, y, ne00, ne01, ne02, s01, s02, s03); +} + +template <typename src_t, typename dst_t> +static void convert_unary_cont_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + convert_unary_cuda<src_t>(vx, y, k, 1, 1, 1, k, k, k, stream); +} + +to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_cont_cuda<float>; + case GGML_TYPE_F16: + return convert_unary_cont_cuda<half>; + default: + return nullptr; + } +} + +to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return dequantize_row_q4_0_cuda; + case GGML_TYPE_Q4_1: + return dequantize_row_q4_1_cuda; + case GGML_TYPE_Q5_0: + return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>; + case GGML_TYPE_Q5_1: + return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>; + case GGML_TYPE_Q8_0: + if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) { + return dequantize_block_q8_0_f16_cuda; + } + return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>; + case GGML_TYPE_Q2_K: + return dequantize_row_q2_K_cuda; + case GGML_TYPE_Q3_K: + return dequantize_row_q3_K_cuda; + case GGML_TYPE_Q4_K: + return dequantize_row_q4_K_cuda; + case GGML_TYPE_Q5_K: + return dequantize_row_q5_K_cuda; + case GGML_TYPE_Q6_K: + return dequantize_row_q6_K_cuda; + case GGML_TYPE_IQ2_XXS: + return dequantize_row_iq2_xxs_cuda; + case GGML_TYPE_IQ2_XS: + return dequantize_row_iq2_xs_cuda; + case GGML_TYPE_IQ2_S: + return dequantize_row_iq2_s_cuda; + case GGML_TYPE_IQ3_XXS: + return dequantize_row_iq3_xxs_cuda; + case GGML_TYPE_IQ1_S: + return dequantize_row_iq1_s_cuda; + case GGML_TYPE_IQ1_M: + return dequantize_row_iq1_m_cuda; + case GGML_TYPE_IQ4_NL: + return dequantize_row_iq4_nl_cuda; + case GGML_TYPE_IQ4_XS: + return dequantize_row_iq4_xs_cuda; + case GGML_TYPE_IQ3_S: + return dequantize_row_iq3_s_cuda; + case GGML_TYPE_MXFP4: + return dequantize_row_mxfp4_cuda; + case GGML_TYPE_F32: + return convert_unary_cont_cuda<float>; + case GGML_TYPE_BF16: + return convert_unary_cont_cuda<nv_bfloat16>; + default: + return nullptr; + } +} + +to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return dequantize_row_q4_0_cuda; + case GGML_TYPE_Q4_1: + return dequantize_row_q4_1_cuda; + case GGML_TYPE_Q5_0: + return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>; + case GGML_TYPE_Q5_1: + return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>; + case GGML_TYPE_Q8_0: + return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>; + case GGML_TYPE_Q2_K: + return dequantize_row_q2_K_cuda; + case GGML_TYPE_Q3_K: + return dequantize_row_q3_K_cuda; + case GGML_TYPE_Q4_K: + return dequantize_row_q4_K_cuda; + case GGML_TYPE_Q5_K: + return dequantize_row_q5_K_cuda; + case GGML_TYPE_Q6_K: + return dequantize_row_q6_K_cuda; + case GGML_TYPE_IQ2_XXS: + return dequantize_row_iq2_xxs_cuda; + case GGML_TYPE_IQ2_XS: + return dequantize_row_iq2_xs_cuda; + case GGML_TYPE_IQ2_S: + return dequantize_row_iq2_s_cuda; + case GGML_TYPE_IQ3_XXS: + return dequantize_row_iq3_xxs_cuda; + case GGML_TYPE_IQ1_S: + return dequantize_row_iq1_s_cuda; + case GGML_TYPE_IQ1_M: + return dequantize_row_iq1_m_cuda; + case GGML_TYPE_IQ4_NL: + return dequantize_row_iq4_nl_cuda; + case GGML_TYPE_IQ4_XS: + return dequantize_row_iq4_xs_cuda; + case GGML_TYPE_IQ3_S: + return dequantize_row_iq3_s_cuda; + case GGML_TYPE_MXFP4: + return dequantize_row_mxfp4_cuda; + case GGML_TYPE_F16: + return convert_unary_cont_cuda<half>; + case GGML_TYPE_BF16: + return convert_unary_cont_cuda<nv_bfloat16>; + default: + return nullptr; + } +} + +to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_cuda<float>; + case GGML_TYPE_Q4_0: + return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>; + case GGML_TYPE_Q4_1: + return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>; + case GGML_TYPE_Q5_0: + return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>; + case GGML_TYPE_Q5_1: + return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>; + case GGML_TYPE_Q8_0: + return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>; + case GGML_TYPE_BF16: + return convert_unary_cuda<nv_bfloat16>; + default: + return nullptr; + } +} + +to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_cuda<float, nv_bfloat16>; + case GGML_TYPE_Q4_0: + return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>; + case GGML_TYPE_Q4_1: + return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>; + case GGML_TYPE_Q5_0: + return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>; + case GGML_TYPE_Q5_1: + return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>; + case GGML_TYPE_Q8_0: + return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>; + case GGML_TYPE_F16: + return convert_unary_cuda<half, nv_bfloat16>; + default: + return nullptr; + } +} + +to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_F16: + return convert_unary_cuda<half, float>; + case GGML_TYPE_Q4_0: + return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>; + case GGML_TYPE_Q4_1: + return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>; + case GGML_TYPE_Q5_0: + return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>; + case GGML_TYPE_Q5_1: + return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>; + case GGML_TYPE_Q8_0: + return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>; + case GGML_TYPE_BF16: + return convert_unary_cuda<nv_bfloat16, float>; + default: + return nullptr; + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/convert.cuh b/llama.cpp/ggml/src/ggml-cuda/convert.cuh new file mode 100644 index 0000000..09f9a33 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/convert.cuh @@ -0,0 +1,56 @@ +#pragma once +#include "common.cuh" + +#define CUDA_DEQUANTIZE_BLOCK_SIZE 256 + +template<typename T> +using to_t_cuda_t = void (*)(const void * x, T * y, int64_t k, cudaStream_t stream); + +typedef to_t_cuda_t<float> to_fp32_cuda_t; +typedef to_t_cuda_t<half> to_fp16_cuda_t; +typedef to_t_cuda_t<nv_bfloat16> to_bf16_cuda_t; + +to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type); + +to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type); + +to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type); + +// TODO more general support for non-contiguous inputs + +template<typename T> +using to_t_nc_cuda_t = void (*)(const void * x, T * y, + int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, + int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream); + +typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t; +typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t; +typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t; + +to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type); +to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type); +to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type); + +template<typename dst_t, typename src_t> + __host__ __device__ inline dst_t ggml_cuda_cast(src_t x) { + if constexpr (std::is_same_v<dst_t, src_t>) { + return x; + } else if constexpr(std::is_same_v<dst_t, nv_bfloat16>) { + return __float2bfloat16(float(x)); + } else if constexpr(std::is_same_v<src_t, nv_bfloat16>) { + return __bfloat162float(x); + } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) { + return __float22half2_rn(x); + } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) { + // bypass compile error on cuda 12.0.1 +#ifdef GGML_USE_HIP + return __float22bfloat162_rn(x); +#else + return {x.x, x.y}; +#endif // GGML_USE_HIP + } else if constexpr(std::is_same_v<dst_t, int32_t>) { + return int32_t(x); + } else { + return float(x); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/count-equal.cu b/llama.cpp/ggml/src/ggml-cuda/count-equal.cu new file mode 100644 index 0000000..0889811 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/count-equal.cu @@ -0,0 +1,64 @@ +#include "common.cuh" +#include "count-equal.cuh" + +#include <cstdint> + +template <typename T> +static __global__ void count_equal(const T * __restrict__ x, const T * __restrict__ y, int64_t * __restrict__ dst, const int64_t dk, const int64_t k) { + const int64_t i0 = (int64_t) blockIdx.x*dk; + const int64_t i1 = min(i0 + dk, k); + + int nequal = 0; + + for (int64_t i = i0 + threadIdx.x; i < i1; i += WARP_SIZE) { + const T xi = x[i]; + const T yi = y[i]; + nequal += xi == yi; + } + + nequal = warp_reduce_sum(nequal); + + if (threadIdx.x != 0) { + return; + } + + atomicAdd((int *) dst, nequal); +} + +void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == src1->type); + GGML_ASSERT( dst->type == GGML_TYPE_I64); + + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + int64_t * dst_d = (int64_t *) dst->data; + + cudaStream_t stream = ctx.stream(); + const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; + + const int64_t ne = ggml_nelements(src0); + GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int"); + const int64_t dne = GGML_PAD((ne + 4*nsm - 1) / (4*nsm), CUDA_COUNT_EQUAL_CHUNK_SIZE); + + CUDA_CHECK(cudaMemsetAsync(dst_d, 0, ggml_nbytes(dst), stream)); + + const dim3 blocks_dim(WARP_SIZE, 1, 1); + const dim3 blocks_num(std::min((int64_t)4*nsm, (ne + CUDA_COUNT_EQUAL_CHUNK_SIZE - 1)/CUDA_COUNT_EQUAL_CHUNK_SIZE), 1, 1); + + switch (src0->type) { + case GGML_TYPE_I32: { + const int * src0_d = (const int *) src0->data; + const int * src1_d = (const int *) src1->data; + count_equal<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_d, dne, ne); + } break; + default: + GGML_ASSERT(false); + break; + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/count-equal.cuh b/llama.cpp/ggml/src/ggml-cuda/count-equal.cuh new file mode 100644 index 0000000..8467da7 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/count-equal.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_COUNT_EQUAL_CHUNK_SIZE 128 + +void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/cp-async.cuh b/llama.cpp/ggml/src/ggml-cuda/cp-async.cuh new file mode 100644 index 0000000..63d0c48 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/cp-async.cuh @@ -0,0 +1,57 @@ +// Simplified API for asynchronous data loading. + +#include "common.cuh" + + +static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) { +#ifdef CP_ASYNC_AVAILABLE + return __cvta_generic_to_shared(generic_ptr); +#else + GGML_UNUSED(generic_ptr); + NO_DEVICE_CODE; + return 0; +#endif // CP_ASYNC_AVAILABLE +} + +// Copies data from global to shared memory, cg == cache global. +// Both the src and dst pointers must be aligned to 16 bit. +// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int. +// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared. +// Only the 16 bit copy is exposed because 4 and 8 bit copies did not yield performance improvements. +template <int preload> +static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) { + static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload"); +#ifdef CP_ASYNC_AVAILABLE +#if CUDART_VERSION >= 11040 + if (preload == 256) { + asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else if (preload == 128) { + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else if (preload == 64) { + asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } else +#endif // CUDART_VERSION >= 11040 + { + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;" + : : "r"(dst), "l"(src)); + } +#else + GGML_UNUSED(dst); + GGML_UNUSED(src); + NO_DEVICE_CODE; +#endif // CP_ASYNC_AVAILABLE +} + +// Makes each thread wait until its asynchronous data copies are done. +// This does NOT provide any additional synchronization. +// In particular, when copying data with multiple warps a call to __syncthreads will be needed. +static __device__ __forceinline__ void cp_async_wait_all() { +#ifdef CP_ASYNC_AVAILABLE + asm volatile("cp.async.wait_all;"); +#else + NO_DEVICE_CODE; +#endif // CP_ASYNC_AVAILABLE +} diff --git a/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh b/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh new file mode 100644 index 0000000..7697c29 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh @@ -0,0 +1,217 @@ +#pragma once + +#include "ggml-common.h" +#include "convert.cuh" + +static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) { + if (x <= val[0]) return 0; + if (x >= val[n-1]) return n-1; + int ml = 0, mu = n-1; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < val[mav]) mu = mav; else ml = mav; + } + return x - val[mu-1] < val[mu] - x ? mu-1 : mu; +} + +static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) { + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK4_0; ++j) { + const float v = x[j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + vmax = v; + } + } + + const float d = vmax / -8; + const float id = d ? 1.0f/d : 0.0f; + + y->d = d; + + for (int j = 0; j < QK4_0/2; ++j) { + const float x0 = x[0 + j]*id; + const float x1 = x[QK4_0/2 + j]*id; + + const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f)); + + y->qs[j] = xi0; + y->qs[j] |= xi1 << 4; + } +} + +static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) { + float vmin = FLT_MAX; + float vmax = -FLT_MAX; + + for (int j = 0; j < QK4_1; ++j) { + const float v = x[j]; + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + } + + const float d = (vmax - vmin) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y->dm.x = d; + y->dm.y = vmin; + + for (int j = 0; j < QK4_1/2; ++j) { + const float x0 = (x[0 + j] - vmin)*id; + const float x1 = (x[QK4_1/2 + j] - vmin)*id; + + const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f)); + + y->qs[j] = xi0; + y->qs[j] |= xi1 << 4; + } +} + +static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) { + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK5_0; ++j) { + const float v = x[j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + vmax = v; + } + } + + const float d = vmax / -16; + const float id = d ? 1.0f/d : 0.0f; + + y->d = d; + + uint32_t qh = 0; + for (int j = 0; j < QK5_0/2; ++j) { + const float x0 = x[0 + j]*id; + const float x1 = x[QK5_0/2 + j]*id; + + const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f)); + + y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2); + } + memcpy(y->qh, &qh, sizeof(qh)); +} + +static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) { + float min = x[0]; + float max = x[0]; + + for (int j = 1; j < QK5_1; ++j) { + const float v = x[j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = d ? 1.0f/d : 0.0f; + + y->dm.x = d; + y->dm.y = min; + + uint32_t qh = 0; + for (int j = 0; j < QK5_1/2; ++j) { + const float x0 = (x[0 + j] - min)*id; + const float x1 = (x[QK5_1/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2); + } + memcpy(y->qh, &qh, sizeof(qh)); +} + +static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = x[j]; + amax = fmaxf(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y->d = d; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = x[j]*id; + y->qs[j] = roundf(x0); + } +} + +static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) { + float amax = 0.0f; + float vmax = 0.0f; + + for (int j = 0; j < QK4_NL; ++j) { + const float v = x[j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + vmax = v; + } + } + + float d = vmax / kvalues_iq4nl[0]; + const float id = d ? 1.0f/d : 0.0f; + + float sumqx = 0, sumq2 = 0; + for (int j = 0; j < QK4_NL/2; ++j) { + const float x0 = x[0 + j]*id; + const float x1 = x[QK4_NL/2 + j]*id; + const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0); + const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1); + y->qs[j] = xi0 | (xi1 << 4); + const float v0 = kvalues_iq4nl[xi0]; + const float v1 = kvalues_iq4nl[xi1]; + const float w0 = x[0 + j]*x[0 + j]; + const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j]; + sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + } + + y->d = sumq2 > 0 ? sumqx/sumq2 : d; +} + +// Wrapper functions for cpy.cu compatibility +static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) { + quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti); +} + +static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) { + quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti); +} + +static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) { + quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti); +} + +static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) { + quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti); +} + +static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) { + quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti); +} + +static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { + quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti); +} + +template<typename src_t, typename dst_t> +static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) { + *(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/cpy.cu b/llama.cpp/ggml/src/ggml-cuda/cpy.cu new file mode 100644 index 0000000..ee84303 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/cpy.cu @@ -0,0 +1,555 @@ +#include "cpy.cuh" +#include "dequantize.cuh" +#include "cpy-utils.cuh" +#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY) +#include "ggml-musa/mudnn.cuh" +#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY + +typedef void (*cpy_kernel_t)(const char * cx, char * cdst); + +const int CUDA_CPY_TILE_DIM_2D = 32; // 2D tile dimension for transposed blocks +const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available +const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows + +template <cpy_kernel_t cpy_1> +static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, + const int64_t nb12, const int64_t nb13) { + const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= ne) { + return; + } + + // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor + // then combine those indices with the corresponding byte offsets to get the total offsets + const int64_t i03 = i/(ne00 * ne01 * ne02); + const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; + + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13; + + cpy_1(cx + x_offset, cdst + dst_offset); +} + +template <typename T> +static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, + const int64_t nb12, const int64_t nb13) { + + const T* src = reinterpret_cast<const T*>(cx); + T* dst = reinterpret_cast<T*>(cdst); + + const int64_t nmat = ne / (ne00 * ne01); + const int64_t n = ne00 * ne01; + + const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x; + const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y; + const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset + const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y; + + __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1]; + +#pragma unroll + for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) { + + const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i; + if (imat >= nmat) + break; + +#pragma unroll + for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) { + if(x < ne01 && y + j < ne00){ + const int row = threadIdx.y+j; + const int col = threadIdx.x * sizeof(float)/sizeof(T); + T *tile2 = reinterpret_cast<T*>(tile[row]); + tile2[col] = src[imat*n + (y+j)*ne01 + x]; + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) { + if (ty + j < ne01 && tx < ne00) { + const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T); + const T *tile2 = reinterpret_cast<const T*>(tile[threadIdx.x]); + dst[imat*n + (ty+j)*ne00 + tx] = tile2[col]; + } + } + } + + GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, + nb12, nb13); +} + +static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { + float * cdstf = (float *)(cdsti); + +#pragma unroll + for (int j = 0; j < QK8_0; j += 2) { + float2 dq; + dequantize_q8_0(cxi, 0, j, dq); + *(cdstf + j) = dq.x; + *(cdstf + j + 1) = dq.y; + } +} + +template<dequantize_kernel_t dequant, int qk> +static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { + float * cdstf = (float *)(cdsti); + +#pragma unroll + for (int j = 0; j < qk/2; j++) { + float2 dq; + dequant(cxi, 0, j, dq); + *(cdstf + j) = dq.x; + *(cdstf + j + qk/2) = dq.y; + } +} + +template <cpy_kernel_t cpy_blck, int qk> +static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, + const int64_t nb12, const int64_t nb13) { + const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk; + + if (i >= ne) { + return; + } + + const int64_t i03 = i/(ne00 * ne01 * ne02); + const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; + + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + + cpy_blck(cx + x_offset, cdst + dst_offset); +} + +template <cpy_kernel_t cpy_blck, int qk> +static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, + const int64_t nb12, const int64_t nb13) { + const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk; + + if (i >= ne) { + return; + } + + const int64_t i03 = i/(ne00 * ne01 * ne02); + const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; + + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + + cpy_blck(cx + x_offset, cdst + dst_offset); +} + +template<typename src_t, typename dst_t> +static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) { + const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= ne) { + return; + } + + const src_t * x = (const src_t *) cx; + dst_t * dst = (dst_t *) cdst; + + dst[i] = ggml_cuda_cast<dst_t>(x[i]); +} + +template<typename src_t, typename dst_t> +static void ggml_cpy_scalar_contiguous_cuda( + const char * cx, char * cdst, const int64_t ne, +cudaStream_t stream) { + + const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + GGML_ASSERT(num_blocks < UINT_MAX); + cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> + (cx, cdst, ne); +} + +template<typename src_t, typename dst_t, bool transposed = false> +static void ggml_cpy_scalar_cuda( + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { + + if (transposed) { + GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed + int64_t ne00n, ne01n, ne02n; + if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here + ne00n = ne00; + ne01n = ne01; + ne02n = ne02; + } else { + ne00n = ne00; + ne01n = ne01*ne02; + ne02n = 1; + } + + int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D; + int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D; + int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM; + GGML_ASSERT(grid_x < UINT_MAX); + GGML_ASSERT(grid_y < USHRT_MAX); + GGML_ASSERT(grid_z < USHRT_MAX); + dim3 dimGrid(grid_x, grid_y, grid_z); + dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); + cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>> + (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + } else { + const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + GGML_ASSERT(num_blocks < UINT_MAX); + cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + } +} + +static void ggml_cpy_f32_q8_0_cuda( + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { + + GGML_ASSERT(ne % QK8_0 == 0); + const int64_t num_blocks = ne / QK8_0; + GGML_ASSERT(num_blocks < UINT_MAX); + cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + +static void ggml_cpy_q8_0_f32_cuda( + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { + + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); + cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + +static void ggml_cpy_f32_q4_0_cuda( + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { + + GGML_ASSERT(ne % QK4_0 == 0); + const int64_t num_blocks = ne / QK4_0; + GGML_ASSERT(num_blocks < UINT_MAX); + cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + +static void ggml_cpy_q4_0_f32_cuda( + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, + cudaStream_t stream) { + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); + cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + +static void ggml_cpy_f32_q4_1_cuda( + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { + + GGML_ASSERT(ne % QK4_1 == 0); + const int64_t num_blocks = ne / QK4_1; + GGML_ASSERT(num_blocks < UINT_MAX); + cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + +static void ggml_cpy_q4_1_f32_cuda( + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, + cudaStream_t stream) { + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); + cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + +static void ggml_cpy_f32_q5_0_cuda( + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { + + GGML_ASSERT(ne % QK5_0 == 0); + const int64_t num_blocks = ne / QK5_0; + GGML_ASSERT(num_blocks < UINT_MAX); + cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + +static void ggml_cpy_q5_0_f32_cuda( + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, + cudaStream_t stream) { + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); + cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + +static void ggml_cpy_f32_q5_1_cuda( + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { + + GGML_ASSERT(ne % QK5_1 == 0); + const int64_t num_blocks = ne / QK5_1; + GGML_ASSERT(num_blocks < UINT_MAX); + cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + +static void ggml_cpy_q5_1_f32_cuda( + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, + cudaStream_t stream) { + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); + cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>( + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, + ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + +static void ggml_cpy_f32_iq4_nl_cuda( + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { + + GGML_ASSERT(ne % QK4_NL == 0); + const int64_t num_blocks = ne / QK4_NL; + GGML_ASSERT(num_blocks < UINT_MAX); + cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>> + (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); +} + +void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) { + const int64_t ne = ggml_nelements(src0); + GGML_ASSERT(ne == ggml_nelements(src1)); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + + //GGML_ASSERT(src0->ne[3] == 1); + + const int64_t nb00 = src0->nb[0]; + const int64_t nb01 = src0->nb[1]; + const int64_t nb02 = src0->nb[2]; + const int64_t nb03 = src0->nb[3]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + + //GGML_ASSERT(src1->ne[3] == 1); + + const int64_t nb10 = src1->nb[0]; + const int64_t nb11 = src1->nb[1]; + const int64_t nb12 = src1->nb[2]; + const int64_t nb13 = src1->nb[3]; + + cudaStream_t main_stream = ctx.stream(); + + char * src0_ddc = (char *) src0->data; + char * src1_ddc = (char *) src1->data; + + const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1); + const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && + src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0); + + if (src0->type == src1->type && contiguous_srcs) { + GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); +#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY) + if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) { + CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0)); + } else +#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY + { + CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); + } + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + if (can_be_transposed) { + ggml_cpy_scalar_cuda<float, float, true> + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else { + ggml_cpy_scalar_cuda<float, float> + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { + if (contiguous_srcs) { + ggml_cpy_scalar_contiguous_cuda<float, nv_bfloat16> + (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_scalar_cuda<float, nv_bfloat16> + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { + if (contiguous_srcs) { + ggml_cpy_scalar_contiguous_cuda<float, half> + (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_scalar_cuda<float, half> + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { + ggml_cpy_f32_q8_0_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q8_0_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) { + ggml_cpy_f32_q4_0_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q4_0_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) { + ggml_cpy_f32_q4_1_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q4_1_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) { + ggml_cpy_f32_q5_0_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q5_0_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) { + ggml_cpy_f32_iq4_nl_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) { + ggml_cpy_f32_q5_1_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { + ggml_cpy_q5_1_f32_cuda + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { + if (can_be_transposed) { + ggml_cpy_scalar_cuda<half, half, true> + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else { + ggml_cpy_scalar_cuda<half, half> + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { + if (contiguous_srcs) { + ggml_cpy_scalar_contiguous_cuda<half, nv_bfloat16> + (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_scalar_cuda<half, nv_bfloat16> + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { + if (contiguous_srcs) { + ggml_cpy_scalar_contiguous_cuda<half, float> + (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_scalar_cuda<half, float> + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { + if (can_be_transposed) { + ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16, true> + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else { + ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16> + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { + if (contiguous_srcs) { + ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, half> + (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_scalar_cuda<nv_bfloat16, half> + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { + if (contiguous_srcs) { + ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, float> + (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_scalar_cuda<nv_bfloat16, float> + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) { + if (can_be_transposed) { + ggml_cpy_scalar_cuda<int32_t, int32_t, true> + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } else { + ggml_cpy_scalar_cuda<int32_t, int32_t> + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) { + if (contiguous_srcs) { + ggml_cpy_scalar_contiguous_cuda<float, int32_t> + (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_scalar_cuda<float, int32_t> + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } + } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) { + if (contiguous_srcs) { + ggml_cpy_scalar_contiguous_cuda<int32_t, float> + (src0_ddc, src1_ddc, ne, main_stream); + } else { + ggml_cpy_scalar_cuda<int32_t, float> + (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream); + } + } else { + GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, + ggml_type_name(src0->type), ggml_type_name(src1->type)); + } +} + +void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + ggml_cuda_cpy(ctx, src0, dst); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/cpy.cuh b/llama.cpp/ggml/src/ggml-cuda/cpy.cuh new file mode 100644 index 0000000..a7a87d8 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/cpy.cuh @@ -0,0 +1,7 @@ +#include "common.cuh" + +#define CUDA_CPY_BLOCK_SIZE 64 + +void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1); + +void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu b/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu new file mode 100644 index 0000000..0c8b081 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu @@ -0,0 +1,177 @@ +#include "common.cuh" +#include "cross-entropy-loss.cuh" +#include "sum.cuh" + +#include <cmath> +#include <cstdint> + +template <bool use_shared> +static __global__ void cross_entropy_loss_f32( + const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) { + extern __shared__ float tmp[]; + + logits += int64_t(blockIdx.x)*nclasses; + labels += int64_t(blockIdx.x)*nclasses; + + // Find maximum for softmax: + float max_logit = -INFINITY; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float val = logits[i]; + max_logit = fmaxf(max_logit, val); + + if (use_shared) { + tmp[i] = val; + } + } + max_logit = warp_reduce_max(max_logit); + + // Calculate log(softmax(logits)) which is just logits - max: + float sum = 0.0f; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float logit_i = use_shared ? tmp[i] : logits[i]; + sum += expf(logit_i - max_logit); + } + sum = warp_reduce_sum(sum); + sum = logf(sum); + + // log(exp(logits - max) / sum) = (logits - max) - log(sum) + float loss = 0.0f; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float logit_i = use_shared ? tmp[i] : logits[i]; + loss += (logit_i - max_logit - sum) * labels[i]; + } + loss = -warp_reduce_sum(loss) / (float)k; + + if (threadIdx.x != 0) { + return; + } + + dst[blockIdx.x] = loss; +} + +template <bool use_shared> +static __global__ void cross_entropy_loss_back_f32( + const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels, + float * __restrict__ dst, const int nclasses) { + extern __shared__ float tmp[]; + + logits += int64_t(blockIdx.x)*nclasses; + labels += int64_t(blockIdx.x)*nclasses; + dst += int64_t(blockIdx.x)*nclasses; + + float maxval = -INFINITY; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float val = logits[i]; + maxval = fmaxf(maxval, val); + + if (use_shared) { + tmp[i] = val; + } + } + maxval = warp_reduce_max(maxval); + + float sum = 0.0f; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval); + sum += val; + + if (use_shared) { + tmp[i] = val; + } else { + dst[i] = val; + } + } + sum = warp_reduce_sum(sum); + const float sm_scale = 1.0f/sum; + + const float d_by_nrows = *grad/gridDim.x; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float val = use_shared ? tmp[i] : dst[i]; + dst[i] = (val*sm_scale - labels[i])*d_by_nrows; + } +} + +void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + ggml_cuda_pool & pool = ctx.pool(); + cudaStream_t stream = ctx.stream(); + + const dim3 blocks_dim(WARP_SIZE, 1, 1); + const dim3 blocks_num(nrows, 1, 1); + const size_t nbytes_shared = ne00*sizeof(float); + + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + + ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x); + + if (nbytes_shared <= smpbo) { + CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo); + cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + } else { + cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows); + } + CUDA_CHECK(cudaGetLastError()); + + // Combine results from individual blocks: + sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream); +} + +void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * grad = dst->src[0]; + const ggml_tensor * src0f = dst->src[1]; + const ggml_tensor * src1f = dst->src[2]; + + GGML_ASSERT(src0f->type == GGML_TYPE_F32); + GGML_ASSERT(src1f->type == GGML_TYPE_F32); + GGML_ASSERT( grad->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_scalar(grad)); + GGML_ASSERT(ggml_is_contiguous(src0f)); + GGML_ASSERT(ggml_is_contiguous(src1f)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0f, src1f)); + GGML_ASSERT(ggml_are_same_shape(src0f, dst)); + + const int64_t ne00 = src0f->ne[0]; + const int64_t nrows = ggml_nrows(src0f); + + const float * grad_d = (const float *) grad->data; + const float * src0f_d = (const float *) src0f->data; + const float * src1f_d = (const float *) src1f->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + const dim3 blocks_dim(WARP_SIZE, 1, 1); + const dim3 blocks_num(nrows, 1, 1); + const size_t nbytes_shared = ne00*sizeof(float); + + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + + if (nbytes_shared <= smpbo) { + CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo); + cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00); + } else { + cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cuh b/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cuh new file mode 100644 index 0000000..9ec7152 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cuh @@ -0,0 +1,7 @@ +#include "common.cuh" + +#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256 + +void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/cumsum.cu b/llama.cpp/ggml/src/ggml-cuda/cumsum.cu new file mode 100644 index 0000000..def9c32 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/cumsum.cu @@ -0,0 +1,307 @@ +#include <algorithm> +#include "cumsum.cuh" +#include "convert.cuh" +#include "ggml-cuda/common.cuh" +#include "ggml.h" + +#ifdef GGML_CUDA_USE_CUB +# include <cub/cub.cuh> +#endif // GGML_CUDA_USE_CUB + +template<typename T, int BLOCK_SIZE> +static __global__ void cumsum_cub_kernel( + const T * __restrict__ src, + T * __restrict__ dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t s1, const int64_t s2, const int64_t s3) { +#ifdef GGML_CUDA_USE_CUB + using BlockScanT = cub::BlockScan<T, BLOCK_SIZE>; + + __shared__ typename BlockScanT::TempStorage temp_storage; + __shared__ T block_carry; + + const int tid = threadIdx.x; + constexpr int UNROLL_FACTOR = 4; + constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR; + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.y; + const int64_t i3 = blockIdx.z; + + if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) { + return; + } + + const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03; + T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3; + + if (tid == 0) { + block_carry = 0; + } + __syncthreads(); + + for (int64_t start = 0; start < ne00; start += TILE_SIZE) { + T items[UNROLL_FACTOR]; + T thread_sum = T(0); + +#pragma unroll + for (int i = 0; i < UNROLL_FACTOR; i++) { + int64_t idx = start + tid * UNROLL_FACTOR + i; + T val = (idx < ne00) ? src_row[idx] : T(0); + thread_sum += val; + items[i] = thread_sum; + } + + // Block-wide scan on thread sums + T thread_prefix; + T block_total; + BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total); + __syncthreads(); + + // Add offset to each item and store + T thread_offset = thread_prefix - thread_sum + block_carry; +#pragma unroll + for (int i = 0; i < UNROLL_FACTOR; i++) { + int64_t idx = start + tid * UNROLL_FACTOR + i; + if (idx < ne00) { + dst_row[idx] = items[i] + thread_offset; + } + } + + __syncthreads(); + + // Update carry for next tile + if (tid == 0) { + block_carry += block_total; + } + } +#else + NO_DEVICE_CODE; +#endif // GGML_CUDA_USE_CUB +} + +// Fallback kernel implementation +template<typename T> +static __global__ void cumsum_kernel( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t s0, const int64_t s1, const int64_t s2, const int64_t s3) { + + GGML_UNUSED_VARS(s00, s0); + + const int tid = threadIdx.x; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + const int lane = tid % warp_size; + const int warp = tid / warp_size; + const int warps_per_block = blockDim.x / warp_size; + + extern __shared__ float smem[]; + float * s_vals = smem; + float * s_warp_sums = smem + blockDim.x; + float * s_carry = smem + blockDim.x + warps_per_block; + float * s_chunk_total = s_carry + 1; + + // Initialize carry + if (tid == 0) { + *s_carry = 0.0f; + } + __syncthreads(); + + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + const int64_t i1 = blockIdx.x; + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03; + T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3; + + // register blocking: process 4 elements per thread to hide latency + // and reduce synchronization overhead + constexpr int num_unroll = 4; + T temp[num_unroll]; + + for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) { + int64_t idx = i + tid * num_unroll; + + // thread local sequential scan + temp[0] = (idx < ne00 ? src_row[idx] : T(0)); +#pragma unroll + for (int64_t j = 1; j < num_unroll; j++) { + temp[j] = temp[j - 1]; + if (idx + j < ne00) { + temp[j] += src_row[idx + j]; + } else { + temp[j] += 0; + } + } + + // last emenent is sum of all values assigned to thread + float val = (idx < ne00) ? ggml_cuda_cast<float, T>(temp[num_unroll - 1]) : 0.0f; + + // Warp inclusive scan + val = warp_prefix_inclusive_sum<T, warp_size>(val); + s_vals[tid] = val; + + if (lane == warp_size - 1) { + s_warp_sums[warp] = val; + } + __syncthreads(); + + // Exclusive scan of warp sums (warp 0 only) + if (warp == 0) { + float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f; + float inc = warp_prefix_inclusive_sum<T, warp_size>(w); + if (tid < warps_per_block) { + s_warp_sums[tid] = inc - w; // exclusive sum + } + if (tid == warps_per_block - 1) { + *s_chunk_total = inc; // total sum of this chunk + } + } + __syncthreads(); + + // write back results + float carry = *s_carry; + // calculate sum offset for this thread + float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1]; + +#pragma unroll + for (int32_t j = 0; j < num_unroll; j++) { + if (idx + j < ne00) { + dst_row[idx + j] = temp[j] + ggml_cuda_cast<T, float>(final_val_offset); + } + } + + __syncthreads(); + + // Update carry for next chunk + if (tid == 0) { + *s_carry += *s_chunk_total; + } + } +} + +#ifdef GGML_CUDA_USE_CUB +template <typename T> +static void cumsum_cub(ggml_cuda_pool & pool, + const T * src, + T * dst, + int64_t ne, + cudaStream_t stream) { + size_t tmp_size = 0; + + // Query how much temp storage CUDA UnBound (CUB) needs + cub::DeviceScan::InclusiveSum(nullptr, // d_temp_storage (null = just query size) + tmp_size, // reference to size (will be set by CUB) + src, // input pointer + dst, // output pointer + ne, // number of elements + stream // CUDA stream to use + ); + + ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size); + + // Perform the inclusive scan + cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream); +} +#endif // GGML_CUDA_USE_CUB + +template<typename T> +static void cumsum_cuda( + [[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + cudaStream_t stream) { + + const size_t type_size = sizeof(T); + bool use_cub = false; +#ifdef GGML_CUDA_USE_CUB + // Check if we can use CUB (data must be contiguous along innermost dimension) + const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size); + + if (is_contiguous) { + use_cub = true; + const int64_t nrows = ne01 * ne02 * ne03; + // TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released + // Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004 + if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) { + for (int i=0; i<nrows; i++) { + cumsum_cub(ctx.pool(), src + i * ne00, dst + i * ne00, ne00, stream); + } + return; + } + } +#endif // GGML_CUDA_USE_CUB + dim3 grid_dims(ne01, ne02, ne03); + const auto &info = ggml_cuda_info().devices[ggml_cuda_get_device()]; + const int warp_size = info.warp_size; + const int num_warps = (ne00 + warp_size - 1) / warp_size; + int block_size = num_warps * warp_size; + block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE); + dim3 block_dims(block_size, 1, 1); + const int warps_per_block = block_size / warp_size; + const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float); + + if (use_cub && ne00 >= 1024) { + cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>( + src, dst, + ne00, ne01, ne02, ne03, + nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } else { + cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } +} + +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == dst->type); + switch(src0->type) { + case GGML_TYPE_F32: + { + cumsum_cuda( + ctx, (const float *)src0->data, (float *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + stream + ); + } break; + // We do not support those on CPU for now anyway, so comment them out because they cause errors on some CI platforms + /*case GGML_TYPE_F16: + { + cumsum_cuda( + (const half *)src0->data, (half *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + stream + ); + } break; + case GGML_TYPE_BF16: + { + cumsum_cuda( + (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + stream + ); + } break;*/ + default: + GGML_ABORT("fatal error"); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/cumsum.cuh b/llama.cpp/ggml/src/ggml-cuda/cumsum.cuh new file mode 100644 index 0000000..782d1d9 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/cumsum.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_CUMSUM_BLOCK_SIZE 256 + +void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh b/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh new file mode 100644 index 0000000..e060fb2 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh @@ -0,0 +1,77 @@ +#include "common.cuh" + +static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ + const block_q4_0 * x = (const block_q4_0 *) vx; + + const float d = x[ib].d; + + const int vui = x[ib].qs[iqs]; + + v.x = vui & 0xF; + v.y = vui >> 4; + + v.x = (v.x - 8.0f) * d; + v.y = (v.y - 8.0f) * d; +} + +static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, float2 & v){ + const block_q4_1 * x = (const block_q4_1 *) vx; + + const float2 dm = __half22float2(x[ib].dm); + + const int vui = x[ib].qs[iqs]; + + v.x = vui & 0xF; + v.y = vui >> 4; + + v.x = (v.x * dm.x) + dm.y; + v.y = (v.y * dm.x) + dm.y; +} + +static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ + const block_q5_0 * x = (const block_q5_0 *) vx; + + const float d = x[ib].d; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); + v.y = ((x[ib].qs[iqs] >> 4) | xh_1); + + v.x = (v.x - 16.0f) * d; + v.y = (v.y - 16.0f) * d; +} + +static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, float2 & v){ + const block_q5_1 * x = (const block_q5_1 *) vx; + + const float2 dm = __half22float2(x[ib].dm); + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + v.x = ((x[ib].qs[iqs] & 0xf) | xh_0); + v.y = ((x[ib].qs[iqs] >> 4) | xh_1); + + v.x = (v.x * dm.x) + dm.y; + v.y = (v.y * dm.x) + dm.y; +} + +static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ + const block_q8_0 * x = (const block_q8_0 *) vx; + + const float d = x[ib].d; + + v.x = x[ib].qs[iqs + 0]; + v.y = x[ib].qs[iqs + 1]; + + v.x *= d; + v.y *= d; +} diff --git a/llama.cpp/ggml/src/ggml-cuda/diag.cu b/llama.cpp/ggml/src/ggml-cuda/diag.cu new file mode 100644 index 0000000..5cea210 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/diag.cu @@ -0,0 +1,77 @@ +#include "convert.cuh" +#include "diag.cuh" +#include "ggml.h" + +template <typename T> +static __global__ void diag_kernel(T * __restrict__ dst, + const T * __restrict__ src, + const int64_t ne0, + const int64_t ne1, + const int64_t ne2, + const int64_t ne3, + const int64_t total_elements) { + const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (global_idx >= total_elements) { + return; + } + + const int64_t i0 = global_idx % ne0; + const int64_t i1 = (global_idx / ne0) % ne1; + const int64_t i2 = (global_idx / (ne0 * ne1)) % ne2; + const int64_t i3 = global_idx / (ne0 * ne1 * ne2); + + const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0; + + if (i0 == i1) { + const int64_t batch_idx = i3 * ne2 + i2; + const int64_t src_idx = batch_idx * ne0 + i0; + dst[dst_idx] = src[src_idx]; + } else { + dst[dst_idx] = ggml_cuda_cast<T>(0); + } + GGML_UNUSED_VARS(ne3); +} + +void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + void * dst_d = dst->data; + const void * src0_d = src0->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + GGML_ASSERT(ne00 == ne0); + GGML_ASSERT(ne01 == 1); + GGML_ASSERT(ne02 == ne2); + GGML_ASSERT(ne03 == ne3); + + const int64_t n_elems = ggml_nelements(dst); + const int64_t num_blocks = (n_elems + CUDA_DIAG_BLOCK_SIZE - 1) / CUDA_DIAG_BLOCK_SIZE; + + switch (dst->type) { + case GGML_TYPE_F32: + diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>((float *) dst_d, (const float *) src0_d, ne0, + ne1, ne2, ne3, n_elems); + break; + case GGML_TYPE_F16: + diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>((half *) dst_d, (const half *) src0_d, ne0, + ne1, ne2, ne3, n_elems); + break; + default: + GGML_ABORT("unsupported type"); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/diag.cuh b/llama.cpp/ggml/src/ggml-cuda/diag.cuh new file mode 100644 index 0000000..7d73e6a --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/diag.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_DIAG_BLOCK_SIZE 256 + +void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/diagmask.cu b/llama.cpp/ggml/src/ggml-cuda/diagmask.cu new file mode 100644 index 0000000..4b713ba --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/diagmask.cu @@ -0,0 +1,40 @@ +#include "diagmask.cuh" + +static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) { + const int col = blockDim.y*blockIdx.y + threadIdx.y; + const int row = blockDim.x*blockIdx.x + threadIdx.x; + + if (col >= ncols) { + return; + } + + const int i = row*ncols + col; + //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i]; + //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU + dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX; +} + +static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) { + const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1); + const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE; + const dim3 block_nums(nrows_x, block_num_x, 1); + diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past); +} + +void ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int nrows0 = ggml_nrows(src0); + + const int n_past = ((int32_t *) dst->op_params)[0]; + + diag_mask_inf_f32_cuda(src0_d, dst_d, ne00, nrows0, ne01, n_past, stream); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/diagmask.cuh b/llama.cpp/ggml/src/ggml-cuda/diagmask.cuh new file mode 100644 index 0000000..6cdbef1 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/diagmask.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 + +void ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh b/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh new file mode 100644 index 0000000..b6a7460 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh @@ -0,0 +1,1036 @@ +#pragma once + +#include "common.cuh" +#include "convert.cuh" +#include "vecdotq.cuh" + +#include <cstdint> + +#define FATTN_KQ_STRIDE 256 +#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. + +// log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable +// by the VKQ accumulators is effectively being shifted up by a factor of 2. +// This reduces issues with numerical overflow but also causes larger values to be flushed to zero. +// However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible. +// Still, the value range should be shifted as much as necessary but as little as possible. +// The macro on the following line shifts it by a factor of 2**3=8, as was needed to fix https://github.com/ggml-org/llama.cpp/issues/18606 . +#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f) + +typedef void (* fattn_kernel_t)( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33); + +typedef float (*vec_dot_KQ_t)( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); + +template <int D, int nthreads> +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { + + const half2 * K_h2 = (const half2 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { + __align__(16) half2 tmp[cpy_ne]; + ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { +#ifdef V_DOT2_F32_F16_AVAILABLE + ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#else + ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#endif // V_DOT2_F32_F16_AVAILABLE + } + } + + return sum; +} + +template<int D, int nthreads> +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_0; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q4_0[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + + const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads]; + sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x - (8/QI8_1)*Q_ds.y); + } + + return sum; +} + +template<int D, int nthreads> +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_1( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_1; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q4_1[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + + const float2 K_dm = __half22float2(K_q4_1[ib].dm); + const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads]; + + sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1; + } + + return sum; +} + +template<int D, int nthreads> +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_0; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q5_0[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + + { + int vh; + ggml_cuda_memcpy_1<sizeof(int), 2>(&vh, K_q5_0[ib].qh); + vh >>= iqs8 * QI5_0; + + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + } + + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + + const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads]; + + sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x - (16/QI8_1)*Q_ds.y); + } + + return sum; +} + +template<int D, int nthreads> +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_1( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_1; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q5_1[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + + { + int vh; + ggml_cuda_memcpy_1<sizeof(int)>(&vh, K_q5_1[ib].qh); + vh >>= iqs8 * QI5_0; + + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + } + + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_cuda_dp4a(v, u, 0); + + const float2 K_dm = __half22float2(K_q5_1[ib].dm); + const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads]; + + sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1; + } + + return sum; +} + +template <int D, int nthreads> +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q8_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads); + + const int ib = k_KQ / QI8_0; + const int iqs = k_KQ % QI8_0; + + int v; + ggml_cuda_memcpy_1<sizeof(v), 2>(&v, K_q8_0[ib].qs + 4*iqs); + + const float2 * Q_ds = (const float2 *) Q_ds_v; + const float Q_d = Q_ds[k_KQ_0/nthreads].x; + + sum += vec_dot_q8_0_q8_1_impl<float, 1>(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d); + } + + return sum; +} + +template <typename Tds, int ni> +static __device__ __forceinline__ void quantize_q8_1_to_shared( + const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) { + + float vals[sizeof(int)] = {0.0f}; +#pragma unroll + for (int l = 0; l < int(sizeof(int)); ++l) { + vals[l] = (ni == WARP_SIZE || threadIdx.x < ni) ? scale * x[4*threadIdx.x + l] : 0.0f; + } + + float amax = fabsf(vals[0]); + float sum = vals[0]; +#pragma unroll + for (int l = 1; l < int(sizeof(int)); ++l) { + amax = fmaxf(amax, fabsf(vals[l])); + sum += vals[l]; + } +#pragma unroll + for (int mask = QI8_1/2; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32)); + sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32); + } + + const float d = amax / 127; + int q32 = 0; + int8_t * q8 = (int8_t *) &q32; + + if (d != 0.0f) { +#pragma unroll + for (int l = 0; l < int(sizeof(int)); ++l) { + q8[l] = roundf(vals[l] / d); + } + } + + yq32[threadIdx.x] = q32; + if (threadIdx.x % QI8_1 == 0 && (ni == WARP_SIZE || threadIdx.x < ni)) { + if (std::is_same<Tds, half2>::value) { + ((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum); + } else { + ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum); + } + } +} + +typedef void (*dequantize_V_t)(const void *, void *, const int64_t); + +template <typename T, int ne> +static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + if constexpr (std::is_same_v<T, half>) { + ggml_cuda_memcpy_1<ne*sizeof(half)>(dst, (const half *) vx + i0); + } else if constexpr (std::is_same_v<T, float>) { + static_assert(ne % 2 == 0, "bad ne"); + __align__(16) half2 tmp[ne/2]; + ggml_cuda_memcpy_1<ne*sizeof(half)>(tmp, (const half *) vx + i0); + float2 * dst_f2 = (float2 *) dst; +#pragma unroll + for (int l = 0; l < ne/2; ++l) { + dst_f2[l] = __half22float2(tmp[l]); + } + } else { + static_assert(std::is_same_v<T, void>, "unsupported type"); + } +} + +template <typename T, int ne> +static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_q4_0 * x = (const block_q4_0 *) vx; + + const int64_t ib = i0 / QK4_0; + const int iqs = i0 % (QK4_0/2); + const int shift = (i0 % QK4_0) / (QK4_0/2); + + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + q = __vsubss4(q, 0x08080808); + + const int8_t * q8 = (const int8_t *) &q; + +#ifdef FP16_AVAILABLE + if constexpr (std::is_same_v<T, half>) { + const half2 d = __half2half2(x[ib].d); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]); + } + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same_v<T, float>) { + const float d = x[ib].d; + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * q8[l]; + } + } else { + static_assert(std::is_same_v<T, void>, "bad type"); + } +} + +template <typename T, int ne> +static __device__ __forceinline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_q4_1 * x = (const block_q4_1 *) vx; + + const int64_t ib = i0 / QK4_1; + const int iqs = i0 % (QK4_1/2); + const int shift = (i0 % QK4_1) / (QK4_1/2); + + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + + const int8_t * q8 = (const int8_t *) &q; + +#ifdef FP16_AVAILABLE + if constexpr (std::is_same_v<T, half>) { + const half2 dm = x[ib].dm; + const half2 d = __half2half2( __low2half(dm)); + const half2 m = __half2half2(__high2half(dm)); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m; + } + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same_v<T, float>) { + const float2 dm = __half22float2(x[ib].dm); + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = dm.x * q8[l] + dm.y; + } + } else { + static_assert(std::is_same_v<T, void>, "bad type"); + } +} + +template <typename T, int ne> +static __device__ __forceinline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_q5_0 * x = (const block_q5_0 *) vx; + + const int64_t ib = i0 / QK5_0; + const int idq = i0 % QK5_0; + const int iqs = i0 % (QK5_0/2); + const int shift = (i0 % QK5_0) / (QK5_0/2); + + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + + { + int qh; + ggml_cuda_memcpy_1<ne, 2>(&qh, x[ib].qh); +#pragma unroll + for (int l = 0; l < ne; ++l) { + q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4); + } + } + + q = __vsubss4(q, 0x10101010); + + const int8_t * q8 = (const int8_t *) &q; + +#ifdef FP16_AVAILABLE + if constexpr (std::is_same_v<T, half>) { + const half2 d = __half2half2(x[ib].d); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]); + } + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same_v<T, float>) { + const float d = x[ib].d; + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * q8[l]; + } + } else { + static_assert(std::is_same_v<T, void>, "bad type"); + } +} + +template <typename T, int ne> +static __device__ __forceinline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_q5_1 * x = (const block_q5_1 *) vx; + + const int64_t ib = i0 / QK5_1; + const int idq = i0 % QK5_1; + const int iqs = i0 % (QK5_1/2); + const int shift = (i0 % QK5_1) / (QK5_1/2); + + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + + { + int qh; + ggml_cuda_memcpy_1<ne>(&qh, x[ib].qh); +#pragma unroll + for (int l = 0; l < ne; ++l) { + q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4); + } + } + + const int8_t * q8 = (const int8_t *) &q; + +#ifdef FP16_AVAILABLE + if constexpr (std::is_same_v<T, half>) { + const half2 dm = x[ib].dm; + const half2 d = __half2half2( __low2half(dm)); + const half2 m = __half2half2(__high2half(dm)); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m; + } + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same_v<T, float>) { + const float2 dm = __half22float2(x[ib].dm); + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = dm.x * q8[l] + dm.y; + } + } else { + static_assert(std::is_same_v<T, void>, "bad type"); + } +} + +template <typename T, int ne> +static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_q8_0 * x = (const block_q8_0 *) vx; + + const int64_t ib = i0 / QK8_0; + const int iqs = i0 % QK8_0; + + static_assert(ne % 2 == 0, "bad ne"); + int8_t qs[ne]; + ggml_cuda_memcpy_1<ne, 2>(qs, x[ib].qs + iqs); + +#ifdef FP16_AVAILABLE + if constexpr (std::is_same<T, half>::value) { + const half2 d = __half2half2(x[ib].d); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((half2 *) dst)[l0/2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]); + } + } else +#endif // FP16_AVAILABLE + if constexpr (std::is_same<T, float>::value) { + const float d = x[ib].d; + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * qs[l]; + } + } else { + static_assert(std::is_same_v<T, void>, "unsupported type"); + } +} + +template <ggml_type type_K, int D, int nthreads> +constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { + if constexpr (type_K == GGML_TYPE_F16) { + return vec_dot_fattn_vec_KQ_f16<D, nthreads>; + } else if constexpr (type_K == GGML_TYPE_Q4_0) { + return vec_dot_fattn_vec_KQ_q4_0<D, nthreads>; + } else if constexpr (type_K == GGML_TYPE_Q4_1) { + return vec_dot_fattn_vec_KQ_q4_1<D, nthreads>; + } else if constexpr (type_K == GGML_TYPE_Q5_0) { + return vec_dot_fattn_vec_KQ_q5_0<D, nthreads>; + } else if constexpr (type_K == GGML_TYPE_Q5_1) { + return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>; + } else if constexpr (type_K == GGML_TYPE_Q8_0) { + return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>; + } else { + static_assert(type_K == -1, "bad type"); + return nullptr; + } +} + +template <ggml_type type_V, typename T, int ne> +constexpr __device__ dequantize_V_t get_dequantize_V() { + if constexpr (type_V == GGML_TYPE_F16) { + return dequantize_V_f16<T, ne>; + } else if constexpr (type_V == GGML_TYPE_Q4_0) { + return dequantize_V_q4_0<T, ne>; + } else if constexpr (type_V == GGML_TYPE_Q4_1) { + return dequantize_V_q4_1<T, ne>; + } else if constexpr (type_V == GGML_TYPE_Q5_0) { + return dequantize_V_q5_0<T, ne>; + } else if constexpr (type_V == GGML_TYPE_Q5_1) { + return dequantize_V_q5_1<T, ne>; + } else if constexpr (type_V == GGML_TYPE_Q8_0) { + return dequantize_V_q8_0<T, ne>; + } else { + static_assert(type_V == -1, "bad type"); + return nullptr; + } +} + +template <int ncols1> +__launch_bounds__(FATTN_KQ_STRIDE/2, 1) +static __global__ void flash_attn_mask_to_KV_max( + const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) { + const int ne31 = gridDim.x; + const int tid = threadIdx.x; + const int sequence = blockIdx.y; + const int jt = blockIdx.x; + + mask += sequence*s33 + jt*ncols1*s31; + + __shared__ int buf_iw[WARP_SIZE]; + if (tid < WARP_SIZE) { + buf_iw[tid] = 1; + } + __syncthreads(); + + int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE; + for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) { + int all_inf = 1; + +#pragma unroll + for (int j = 0; j < ncols1; ++j) { + const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]); + all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y)); + } + + all_inf = warp_reduce_all(all_inf); + if (tid % WARP_SIZE == 0) { + buf_iw[tid / WARP_SIZE] = all_inf; + } + __syncthreads(); + all_inf = buf_iw[tid % WARP_SIZE]; + __syncthreads(); + all_inf = warp_reduce_all(all_inf); + + if (!all_inf) { + break; + } + } + + // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE. + // If the break was triggered it's the lower edge of the tile with the first non-masked values. + // In either case, walk back the decrementation by FATTN_KQ_STRIDE. + KV_max_sj += FATTN_KQ_STRIDE; + + if (threadIdx.x != 0) { + return; + } + + KV_max[sequence*ne31 + jt] = KV_max_sj; +} + +template<int D, int ncols1, int ncols2> // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_stream_k_fixup( + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, + const int ne11, const int ne12, const int nbatch_fa) { + constexpr int ncols = ncols1*ncols2; + + const int bidx0 = blockIdx.x; + const int j = blockIdx.y; + const int c = blockIdx.z; + const int jc = j*ncols2 + c; + const int tid = threadIdx.x; + + const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2; + + const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; + const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } + + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12); + const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); + const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); + const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; + + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. + + if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) { + return; + } + + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; + + // Load the partial result that needs a fixup: + float dst_val = 0.0f; + float max_val = 0.0f; + float rowsum = 0.0f; + { + dst_val = *dst; + + const float2 tmp = dst_fixup[bidx0*ncols + jc]; + max_val = tmp.x; + rowsum = tmp.y; + } + + // Iterate over previous blocks and compute the combined results. + // All CUDA blocks that get here must have a previous block that needs a fixup. + int bidx = bidx0 - 1; + int kbc_stop = kbc0; + while(true) { + const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; + continue; + } + + const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; + + const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc]; + + // Scale the current and new value accumulators depending on the max. values. + const float max_val_new = fmaxf(max_val, tmp.x); + + const float diff_val = max_val - max_val_new; + const float diff_add = tmp.x - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + + dst_val = scale_val*dst_val + scale_add*dst_add; + rowsum = scale_val*rowsum + scale_add*tmp.y; + + max_val = max_val_new; + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { + break; + } + bidx--; + kbc_stop = kbc; + } + + // Write back final result: + *dst = dst_val / rowsum; +} + +template<int D> // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_combine_results( + const float * __restrict__ VKQ_parts, + const float2 * __restrict__ VKQ_meta, + float * __restrict__ dst, + const int parallel_blocks) { + // Dimension 0: threadIdx.x + // Dimension 1: blockIdx.x + // Dimension 2: blockIdx.y + // Dimension 3: blockIdx.z + // Memory layout is permuted with [0, 2, 1, 3] + + const int ne01 = gridDim.x; + const int ne02 = gridDim.y; + + const int col = blockIdx.x; + const int head = blockIdx.y; + const int sequence = blockIdx.z; + + const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head; + + VKQ_parts += j_dst_unrolled * parallel_blocks*D; + VKQ_meta += j_dst_unrolled * parallel_blocks; + dst += j_dst_unrolled * D; + + const int tid = threadIdx.x; + __builtin_assume(tid < D); + + extern __shared__ float2 meta[]; + for (int i = tid; i < 2*parallel_blocks; i += D) { + ((float *) meta)[i] = ((const float *)VKQ_meta) [i]; + } + + __syncthreads(); + + float kqmax = meta[0].x; + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = max(kqmax, meta[l].x); + } + + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; + for (int l = 0; l < parallel_blocks; ++l) { + const float KQ_max_scale = expf(meta[l].x - kqmax); + + VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid]; + VKQ_denominator += KQ_max_scale * meta[l].y; + } + + dst[tid] = VKQ_numerator / VKQ_denominator; +} + +template <int DV, int ncols1, int ncols2> +void launch_fattn( + ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, + const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE +) { + constexpr int ncols = ncols1 * ncols2; + + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs)); + + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + ggml_tensor * KQV = dst; + + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(Q->nb[0] == ggml_element_size(Q)); + GGML_ASSERT(K->nb[0] == ggml_element_size(K)); + GGML_ASSERT(V->nb[0] == ggml_element_size(V)); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + + ggml_cuda_pool & pool = ctx.pool(); + cudaStream_t main_stream = ctx.stream(); + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int nsm = ggml_cuda_info().devices[id].nsm; + + ggml_cuda_pool_alloc<half> K_f16(pool); + ggml_cuda_pool_alloc<half> V_f16(pool); + ggml_cuda_pool_alloc<int> KV_max(pool); + ggml_cuda_pool_alloc<float> dst_tmp(pool); + ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool); + + const char * K_data = (const char *) K->data; + size_t nb11 = K->nb[1]; + size_t nb12 = K->nb[2]; + size_t nb13 = K->nb[3]; + + const char * V_data = (const char *) V->data; + size_t nb21 = V->nb[1]; + size_t nb22 = V->nb[2]; + size_t nb23 = V->nb[3]; + + if (need_f16_K && K->type != GGML_TYPE_F16) { + const size_t bs = ggml_blck_size(K->type); + const size_t ts = ggml_type_size(K->type); + + K_f16.alloc(ggml_nelements(K)); + if (ggml_is_contiguously_allocated(K)) { + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); + to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); + + nb11 = nb11*bs*sizeof(half)/ts; + nb12 = nb12*bs*sizeof(half)/ts; + nb13 = nb13*bs*sizeof(half)/ts; + } else { + GGML_ASSERT(K->nb[0] == ts); + to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type); + const int64_t s01 = nb11 / ts; + const int64_t s02 = nb12 / ts; + const int64_t s03 = nb13 / ts; + to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); + + nb11 = K->ne[0] * sizeof(half); + nb12 = K->ne[1] * nb11; + nb13 = K->ne[2] * nb12; + } + K_data = (char *) K_f16.ptr; + } + + if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V_is_K_view) { + V_data = K_data; + nb21 = nb11; + nb22 = nb12; + nb23 = nb13; + } else { + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); + + V_f16.alloc(ggml_nelements(V)); + if (ggml_is_contiguously_allocated(V)) { + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + nb21 = nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; + } else { + GGML_ASSERT(V->nb[0] == ts); + to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type); + const int64_t s01 = nb21 / ts; + const int64_t s02 = nb22 / ts; + const int64_t s03 = nb23 / ts; + to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + + nb21 = V->ne[0] * sizeof(half); + nb22 = V->ne[1] * nb21; + nb23 = V->ne[2] * nb22; + } + V_data = (char *) V_f16.ptr; + } + } + + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2); + const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3]; + + // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. + // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or + // multiple sequences of possibly different lengths. + if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { + const int s31 = mask->nb[1] / sizeof(half2); + const int s33 = mask->nb[3] / sizeof(half2); + + const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1); + const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1); + + const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y; + const int iter_k = K->ne[1] / FATTN_KQ_STRIDE; + + KV_max.alloc(ne_KV_max); + flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>> + ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33); + CUDA_CHECK(cudaGetLastError()); + } + + const dim3 block_dim(warp_size, nwarps, 1); + int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. + CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); + GGML_ASSERT(max_blocks_per_sm > 0); + int parallel_blocks = max_blocks_per_sm; + + dim3 blocks_num; + if (stream_k) { + // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. + const int max_blocks = max_blocks_per_sm*nsm; + const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; + const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); + + const int nblocks_stream_k = max_blocks; + + const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75; + + blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; + blocks_num.y = 1; + blocks_num.z = 1; + + if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2))); + } + } else { + const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size. + + // parallel_blocks must not be larger than what the tensor size allows: + parallel_blocks = std::min(parallel_blocks, ntiles_KQ); + + // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects. + // Test whether parallel_blocks can be set to a higher value for better efficiency. + const int blocks_per_wave = nsm * max_blocks_per_sm; + int nwaves_best = 0; + int efficiency_percent_best = 0; + for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) { + const int nblocks_total = ntiles_total * parallel_blocks_test; + const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave; + const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave); + + // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead. + if (efficiency_percent_best >= 95 && nwaves > nwaves_best) { + break; + } + + if (efficiency_percent > efficiency_percent_best) { + nwaves_best = nwaves; + efficiency_percent_best = efficiency_percent; + parallel_blocks = parallel_blocks_test; + } + } + + blocks_num.x = ntiles_x; + blocks_num.y = parallel_blocks; + blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3]; + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + } + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head)))); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + // TODO other tensor dimensions after removal of WMMA kernel: + const uint3 ne01 = init_fastdiv_values(Q->ne[1]); + + GGML_ASSERT(block_dim.x % warp_size == 0); + fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>( + (const char *) Q->data, + K_data, + V_data, + mask ? ((const char *) mask->data) : nullptr, + sinks ? ((const char *) sinks->data) : nullptr, + KV_max.ptr, + !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, + scale, max_bias, m0, m1, n_head_log2, logit_softcap, + Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, + nb21, nb22, nb23, + mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, + mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0 + ); + CUDA_CHECK(cudaGetLastError()); + + if (stream_k) { + if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + const dim3 block_dim_combine(DV, 1, 1); + const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; + + flash_attn_stream_k_fixup<DV, ncols1, ncols2> + <<<blocks_num_combine, block_dim_combine, 0, main_stream>>> + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa); + } + } else if (parallel_blocks > 1) { + const dim3 block_dim_combine(DV, 1, 1); + const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]); + const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); + + flash_attn_combine_results<DV> + <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>> + (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); + } + CUDA_CHECK(cudaGetLastError()); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh new file mode 100644 index 0000000..0b8ef90 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -0,0 +1,1750 @@ +#include "common.cuh" +#include "cp-async.cuh" +#include "mma.cuh" +#include "fattn-common.cuh" + +using namespace ggml_cuda_mma; + +// Config options for the MMA kernel. +// Should not affect results, only speed/register pressure/shared memory use. +struct fattn_mma_config { + int nthreads; // Number of threads per CUDA block. + int occupancy; // Targeted occupancy for the MMA kernel. + int nbatch_fa; // Number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators. + int nbatch_K2; // Number of K half2 values in direction of DKQ to load in parallel. + int nbatch_V2; // Number of V half2 values in direction of DV to load in parallel. + int nbatch_combine; // Number of VKQ half2 values in direction of DV to combine in parallel. + int nstages_target; // Number of pipeline stages to use ideally, 1 == always load data synchronously, 2 == preload data if there is hardware support. + bool Q_in_reg; // Whether the Q values should be kept permanently in registers. + + constexpr __host__ __device__ fattn_mma_config( + int nthreads, int occupancy, int nbatch_fa, int nbatch_K2, int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) : + nthreads(nthreads), occupancy(occupancy), nbatch_fa(nbatch_fa), nbatch_K2(nbatch_K2), nbatch_V2(nbatch_V2), nbatch_combine(nbatch_combine), + nstages_target(nstages_target), Q_in_reg(Q_in_reg) {} +}; + +#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads_, occupancy_, nbatch_fa_, nbatch_K2_, nbatch_V2_, nbatch_combine_, nstages_target_, Q_in_reg_) \ + if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \ + static_assert((nthreads_) % 32 == 0 && (nthreads_) <= 512, "bad nthreads"); \ + static_assert( (occupancy_) <= 8, "bad occupancy"); \ + static_assert((nbatch_fa_) % 32 == 0 && (nbatch_fa_) <= 256, "bad nbatch_fa"); \ + static_assert((nbatch_K2_) % 4 == 0 && (nbatch_K2_) <= 512, "bad nbatch_K2"); \ + static_assert((nbatch_V2_) % 4 == 0 && (nbatch_V2_) <= 256, "bad nbatch_V2"); \ + static_assert((nbatch_combine_) % 4 == 0 && (nbatch_combine_) <= 128, "bad nbatch_combine"); \ + static_assert((nstages_target_) >= 1 && (nstages_target_) <= 2, "bad nstages_target"); \ + return fattn_mma_config{(nthreads_), (occupancy_), (nbatch_fa_), (nbatch_K2_), (nbatch_V2_), (nbatch_combine_), (nstages_target_), (Q_in_reg_)}; \ + } \ + +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); +} + +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 128, 2, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +} + +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false); + + // TODO tune specifically for Volta + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +} + +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + + // TODO tune specifically for RDNA + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +} + +static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) { + if (ampere_mma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); + } + if (turing_mma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); + } + if (amd_wmma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols); + } + GGML_ASSERT(volta_mma_available(cc)); + return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); +} + +static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) { +#if defined(AMPERE_MMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +#elif defined(TURING_MMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); +#elif defined(VOLTA_MMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); +#elif defined(AMD_WMMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols); +#else + GGML_UNUSED_VARS(DKQ, DV, ncols); + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); +#endif // defined(AMPERE_MMA_AVAILABLE) +} + +static __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nthreads; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nthreads; +} + +static __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).occupancy; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).occupancy; +} + +static __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_fa; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_fa; +} + +static __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_K2; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_K2; +} + +static __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_V2; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_V2; +} + +static __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_combine; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_combine; +} + +static __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nstages_target; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nstages_target; +} + +static __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).Q_in_reg; +} + +static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) { + return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg; +} + +static constexpr __device__ int get_cols_per_thread() { +#if defined(AMD_WMMA_AVAILABLE) + return 1; // RDNA has a single column. +#else + return 2; // This is specifically KQ columns, Volta only has a single VKQ column. +#endif // defined(AMD_WMMA_AVAILABLE) +} + +static __host__ int get_cols_per_warp(const int cc) { + if (turing_mma_available(cc) || amd_wmma_available(cc)) { + return 16; + } else { + // Volta + return 32; + } +} + +// ------------------------------------------------------------------------------------------------------------------ + +static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) { + return cp_async_available(cc) && ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2, cc) : 0; +} + +static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2) { +#ifdef CP_ASYNC_AVAILABLE + return ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2) : 0; +#else + GGML_UNUSED_VARS(DKQ, DV, ncols1, ncols2); + return 0; +#endif // CP_ASYNC_AVAILABLE +} + +// ------------------------------------------------------------------------------------------------------------------ + +template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check> +static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) { + // K/V data is loaded with decreasing granularity for D for better memory bandwidth. + // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. + if constexpr (use_cp_async) { + static_assert(!oob_check, "OOB check not compatible with cp_async"); + constexpr int preload = 64; + constexpr int h2_per_chunk = 16/sizeof(half2); + const int chunks_per_row = D2 / h2_per_chunk; + + const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); + + auto load = [&] __device__ (auto n) { + const int stride_k = WARP_SIZE >> n; + const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } + +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); + } + } + }; + // 1: max 32*16=512 bytes, 256 half + // 2: max 16*16=256 bytes, 128 half + // 3: max 8*16=128 bytes, 64 half + // 4: max 4*16= 64 bytes, 32 half + // 5: max 2*16= 32 bytes, 16 half + // 6: max 1*16= 16 bytes, 8 half + ggml_cuda_unroll<6>{}(load); + } else { + // TODO use ggml_cuda_memcpy_1 + auto load = [&] __device__ (const int n) { + const int stride_k = WARP_SIZE >> n; + const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); + const int k0_stop = D2 - D2 % (1*stride_k); + const int stride_i = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { + break; + } + +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f); + } + } + }; + // 1: max 32* 4=128 bytes, 64 half + // 2: max 16* 4= 64 bytes, 32 half + // 3: max 8* 4= 32 bytes, 16 half + // 4: max 4* 4= 16 bytes, 8 half + ggml_cuda_unroll<4>{}(load); + } +} + +template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check> +static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( + const half * const __restrict__ mask_h, half * const __restrict__ tile_mask, + const int stride_mask, const int i_sup, const int j0, const uint3 ne01) { + if constexpr (use_cp_async) { + static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa"); + static_assert(!oob_check, "OOB check incompatible with cp_async"); + constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64; + constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; + + const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask); + +#pragma unroll + for (int j1 = 0; j1 < ncols1; j1 += stride_j) { + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_vram = fastmodulo(j0 + j_sram, ne01); + + if (j1 + stride_j > ncols1 && j_sram >= ncols1) { + break; + } + + const int i = 8 * (threadIdx.x % (nbatch_fa/8)); + + cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i); + } + } else if constexpr (oob_check) { +#pragma unroll + for (int j1 = 0; j1 < ncols1; j1 += nwarps) { + const int j_sram = j1 + threadIdx.y; + const int j_vram = fastmodulo(j0 + j_sram, ne01); + + if (j1 + nwarps > ncols1 && j_sram >= ncols1) { + break; + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f); + } + } + } else if constexpr (nbatch_fa < 2*WARP_SIZE) { + constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; + constexpr int stride_j = nwarps * cols_per_warp; +#pragma unroll + for (int j1 = 0; j1 < ncols1; j1 += stride_j) { + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_vram = fastmodulo(j0 + j_sram, ne01); + + if (j1 + stride_j > ncols1 && j_sram >= ncols1) { + break; + } + + const int i = threadIdx.x % (WARP_SIZE/cols_per_warp); + + ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i); + } + } else { +#pragma unroll + for (int j1 = 0; j1 < ncols1; j1 += nwarps) { + const int j_sram = j1 + threadIdx.y; + const int j_vram = fastmodulo(j0 + j_sram, ne01); + + if (j1 + nwarps > ncols1 && j_sram >= ncols1) { + break; + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) { + const int i = i0 + 2*threadIdx.x; + + ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i); + } + } + } +} + +template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, + bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check, + typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ> +static __device__ __forceinline__ void flash_attn_ext_f16_iter( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half * const __restrict__ mask_h, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const uint3 ne01, + const int ne02, + const int stride_K, + const int stride_V, + const int stride_mask, + half2 * const __restrict__ tile_Q, + half2 * const __restrict__ tile_K, + half2 * const __restrict__ tile_V, + half * const __restrict__ tile_mask, + T_B_KQ * const __restrict__ Q_B, + T_C_VKQ * const __restrict__ VKQ_C, + float * const __restrict__ KQ_max, + float * const __restrict__ KQ_rowsum, + const int jt, + const int kb0, + const int k_VKQ_sup) { +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + constexpr int ncols = ncols1 * ncols2; + constexpr int cols_per_warp = T_B_KQ::I; + constexpr int cols_per_thread = get_cols_per_thread(); + constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. + constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); + constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols); + constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols); + constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols); + constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2); + + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = nbatch_K2 + 4; + + constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4; + + const int k_VKQ_0 = kb0 * nbatch_fa; +#if defined(TURING_MMA_AVAILABLE) + T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))]; +#elif defined(AMD_WMMA_AVAILABLE) + T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; +#else // Volta + T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; +#endif // defined(TURING_MMA_AVAILABLE) + + if constexpr (nstages > 1) { + static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline"); + static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading"); + static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); + constexpr bool use_cp_async = true; + cp_async_wait_all(); + __syncthreads(); + flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check> + (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V, k_VKQ_sup); + } else { + constexpr bool use_cp_async = nstages == 1; + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check> + (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); + } + } + + // For MLA K and V have the same data. + // Therefore, iterate over K in reverse and later re-use the data if possible. +#pragma unroll + for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) { + const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; + const int k0_diff = k0_stop - k0_start; + + if constexpr (nstages <= 1) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check> + (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); + } + + // Calculate tile of KQ: + if constexpr (Q_in_reg) { +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I; +#pragma unroll + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { + T_A_KQ K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); + if constexpr (cols_per_warp == 8) { + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); + } else { + // Wide version of KQ_C is column-major +#if defined(AMD_WMMA_AVAILABLE) + // RDNA matrix C is column-major. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); +#else + // swap A and B for CUDA. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A); +#endif // defined(AMD_WMMA_AVAILABLE) + } + } + } + } else { +#pragma unroll + for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { + load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); + +#pragma unroll + for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) { + const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I; + + T_A_KQ K_A; + load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); + + if constexpr (cols_per_warp == 8) { + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); + } else { + // Wide version of KQ_C is column-major +#if defined(AMD_WMMA_AVAILABLE) + // RDNA matrix C is column-major. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); +#else + // swap A and B for CUDA. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); +#endif // defined(AMD_WMMA_AVAILABLE) + } + } + } + } + + if constexpr (nstages <= 1) { + __syncthreads(); // Only needed if tile_K == tile_V. + } + } + + if (use_logit_softcap) { + constexpr int stride = cols_per_warp == 8 ? np*T_C_KQ::I : np*T_C_KQ::J; + static_assert(nbatch_fa % stride == 0, "bad loop size"); +#pragma unroll + for (int i = 0; i < nbatch_fa/stride; ++i) { +#pragma unroll + for (int l = 0; l < T_C_KQ::ne; ++l) { + KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]); + } + } + } + + float KQ_max_new[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max_new[col] = KQ_max[col]; + } + float KQ_rowsum_add[cols_per_thread] = {0.0f}; + + if constexpr (cols_per_warp == 8) { + if (ncols2 > 1 || mask_h) { +#pragma unroll + for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::I) { + const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::I; +#pragma unroll + for (int l = 0; l < T_C_KQ::ne; ++l) { + const int i = i0 + T_C_KQ::get_i(l); + const int j = ((threadIdx.y / np)*T_C_KQ::J + T_C_KQ::get_j(l)) / ncols2; + + KQ_C[i00/(np*T_C_KQ::I)].x[l] += slope * __half2float(tile_mask[j*(nbatch_fa + 8) + i]); + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) { +#pragma unroll + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { +#if defined(AMD_WMMA_AVAILABLE) + constexpr int KQ_idx = 0; +#else + // Turing + Volta: + const int KQ_idx = l % 2; +#endif // defined(AMD_WMMA_AVAILABLE) + KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET); + } + } + } + + // Values per KQ column are spread across 8 threads: +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = 16; offset >= 4; offset >>= 1) { + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + } + } + + static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size"); +#pragma unroll + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) { +#pragma unroll + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { +#if defined(AMD_WMMA_AVAILABLE) + constexpr int KQ_idx = 0; +#else + // Turing + Volta: + const int KQ_idx = l % 2; +#endif // defined(AMD_WMMA_AVAILABLE) + KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]); + KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l]; + } else { + KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f; + } + } + } + } else { // not Turing mma or T_B_KQ::I > 8 + if (ncols2 > 1 || mask_h) { +#pragma unroll + for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) { + const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J; +#pragma unroll + for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) { + const int i = (i0 + T_C_KQ::get_j(l0)) / 2; + const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l0)) / ncols2; + + const float2 tmp = __half22float2(((const half2 *)tile_mask)[j*(nbatch_fa/2 + 4) + i]); + KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x; + KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y; + } + } + } + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. + static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size"); +#pragma unroll + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { +#pragma unroll + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { +#if defined(AMD_WMMA_AVAILABLE) + constexpr int KQ_idx = 0; +#else + // Turing + Volta: + const int KQ_idx = (l/2) % 2; +#endif // defined(AMD_WMMA_AVAILABLE) + KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET); + } + } + } + +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#if defined(TURING_MMA_AVAILABLE) + // Values per KQ column are spread across 4 threads: + constexpr int offset_first = 2; + constexpr int offset_last = 1; +#elif defined(AMD_WMMA_AVAILABLE) + // Values per KQ column are spread across 2 threads: + constexpr int offset_first = 16; + constexpr int offset_last = 16; +#else // Volta + // Values per KQ column are spread across 2 threads: + constexpr int offset_first = 2; + constexpr int offset_last = 2; +#endif // defined(TURING_MMA_AVAILABLE) +#pragma unroll + for (int offset = offset_first; offset >= offset_last; offset >>= 1) { + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + } + } + + static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size"); +#pragma unroll + for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { +#pragma unroll + for (int l = 0; l < T_C_KQ::ne; ++l) { + if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { +#if defined(AMD_WMMA_AVAILABLE) + constexpr int KQ_idx = 0; +#else + // Turing + Volta: + const int KQ_idx = (l/2) % 2; +#endif // defined(AMD_WMMA_AVAILABLE) + KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]); + KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l]; + } else { + KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f; + } + } + } + } + + { + float KQ_max_scale[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const float KQ_max_diff = KQ_max[col] - KQ_max_new[col]; + KQ_max_scale[col] = expf(KQ_max_diff); + KQ_max[col] = KQ_max_new[col]; + + *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD; + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col]; + } + +#if defined(TURING_MMA_AVAILABLE) + if constexpr (cols_per_warp == 8) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]); +#pragma unroll + for (int i = 0; i < DV/T_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) { + VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) + const half2 KQ_max_scale_h2 = make_half2( + KQ_max_scale[0], KQ_max_scale[0]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } +#else // Volta + const half2 KQ_max_scale_h2 = make_half2( + KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } +#endif // defined(TURING_MMA_AVAILABLE) + } + + // Convert KQ C tiles into B tiles for VKQ calculation: + T_B_VKQ B[nbatch_fa/(np*2*T_B_VKQ::J)]; + static_assert(nbatch_fa % (np*2*T_B_VKQ::J) == 0, "bad loop size"); + if constexpr (cols_per_warp == 8) { +#pragma unroll + for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) { + B[k] = get_transposed(get_half2(KQ_C[k])); + } + } else { + for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) { + B[k] = get_half2(KQ_C[k]); + } + } + + if constexpr (nstages > 1) { + static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading"); + // Preload K tile for next iteration: + constexpr bool use_cp_async = true; + cp_async_wait_all(); + __syncthreads(); + if (!last_iter) { + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check> + (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); + } + flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check> + (K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup); + } + } + + +#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) + T_A_VKQ A_identity; + make_identity_mat(A_identity); +#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) + + // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: +#pragma unroll + for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) { + static_assert(DV % (2*nbatch_V2) == 0, "bad loop size"); + const int i0_stop = i0_start + 2*nbatch_V2; + const int i0_diff = i0_stop - i0_start; + + if constexpr (nstages <= 1) { + if (!V_is_K_view || i0_stop > 2*nbatch_K2) { + constexpr bool use_cp_async = nstages == 1; + flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check> + (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup); + if (use_cp_async) { + cp_async_wait_all(); + } + __syncthreads(); + } + } + const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2; + +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J; +#pragma unroll + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size"); +#pragma unroll + for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) { + const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J; + + T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. +#if defined(LDMATRIX_TRANS_AVAILABLE) + load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); +#else + // TODO: Try to transpose tile_V when loading gmem to smem. + // Use mma to transpose T_A_VKQ for RDNA. + T_A_VKQ A_trans; + load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + mma(A, A_trans, A_identity); +#endif // defined(TURING_MMA_AVAILABLE) + if constexpr (T_B_KQ::I == 8) { + mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); + } else { + // Wide version of VKQ_C is column-major. +#if defined(AMD_WMMA_AVAILABLE) + // RDNA matrix C is column-major. + mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); +#else + // swap A and B for CUDA. + mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A); +#endif // defined(AMD_WMMA_AVAILABLE) + } + } + } +#else // Volta + constexpr int i0_stride = 2*T_C_VKQ::J; +#pragma unroll + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + static_assert(nbatch_fa % (np*T_A_VKQ::I) == 0, "bad loop size"); + static_assert(2*T_B_VKQ::J == T_A_VKQ::I, "bad tile sizes"); +#pragma unroll + for (int k00 = 0; k00 < nbatch_fa; k00 += np*T_A_VKQ::I) { + const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::I; + + T_A_VKQ A; // Transposed in both SRAM and registers, load normally. + load_ldmatrix(A, tile_V_i + k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A); + } + } +#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + if constexpr (nstages <= 1) { + __syncthreads(); // Only needed if tile_K == tile_V. + } + } +#else + GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, + scale, slope, logit_softcap, ne01, ne02, + stride_K, stride_V, stride_mask, + tile_Q, tile_K, tile_V, tile_mask, + Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + NO_DEVICE_CODE; +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +} + +#if defined(TURING_MMA_AVAILABLE) +template<int ncols> struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +template<> struct mma_tile_sizes<8> { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile< 8, 8, half2>; // column-major + using T_C_KQ = tile<16, 8, float>; // row-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile< 8, 8, half2>; // column-major + using T_C_VKQ = tile<16, 4, half2>; // row-major +}; +#elif defined(AMD_WMMA_AVAILABLE) +template<int ncols> struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +#else // Volta +template<int ncols> struct mma_tile_sizes { + using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile< 8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED>; // column-major + using T_B_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major +}; +#endif // defined(TURING_MMA_AVAILABLE) + +template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup> +static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( + const float2 * const __restrict__ Q_f2, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half * const __restrict__ mask_h, + const float * const __restrict__ sinks_f, + float2 * const __restrict__ dstk, + float2 * const __restrict__ dstk_fixup, + const float scale, + const float slope, + const float logit_softcap, + const uint3 ne01, + const int ne02, + const int gqa_ratio, + const int ne11, + const int stride_Q1, + const int stride_Q2, + const int stride_K, + const int stride_V, + const int stride_mask, + const int jt, + const int zt_gqa, + const int kb0_start, + const int kb0_stop) { +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + constexpr int ncols = ncols1 * ncols2; + using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ; + using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ; + using T_C_KQ = typename mma_tile_sizes<ncols>::T_C_KQ; + using T_A_VKQ = typename mma_tile_sizes<ncols>::T_A_VKQ; + using T_B_VKQ = typename mma_tile_sizes<ncols>::T_B_VKQ; + using T_C_VKQ = typename mma_tile_sizes<ncols>::T_C_VKQ; + + constexpr int cols_per_warp = T_B_KQ::I; + constexpr int cols_per_thread = get_cols_per_thread(); + constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. + constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols); + constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols); + constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols); + constexpr int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols); + constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols); + constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2); + + if (cols_per_warp > ncols) { + NO_DEVICE_CODE; + return; + } + + static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps"); + + constexpr int stride_tile_Q = DKQ/2 + 4; + constexpr int stride_tile_K = nbatch_K2 + 4; + + constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4; + constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; + + extern __shared__ half2 tile_Q[]; + half2 * tile_K = Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q; + half2 * tile_V = nstages > 1 ? tile_K + nbatch_fa * stride_tile_K : tile_K; + half * tile_mask = (half *) (nstages > 1 ? tile_V + nbatch_fa * stride_tile_V : tile_V + nbatch_fa * stride_tile_KV_max); + + T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)]; +#if defined(TURING_MMA_AVAILABLE) + T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)]; +#elif defined(AMD_WMMA_AVAILABLE) + T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; +#else // Volta + T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; +#endif // defined(TURING_MMA_AVAILABLE) + + float KQ_rowsum[cols_per_thread] = {0.0f}; + float KQ_max[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + KQ_max[col] = -FLT_MAX/2.0f; + } + + // Load Q data into tile_Q, either temporarily or permanently. + // Q in registers is faster, but register pressure is the biggest bottleneck. + // The loading is done with decreasing granularity for D for better memory bandwidth. + const half2 scale_h2 = make_half2(scale, scale); +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k); + const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + continue; + } + +#pragma unroll + for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) { + const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) { + break; + } + + const int j = jc / ncols2; + const int c = jc % ncols2; + + if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; + tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y); + } + } else { +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f); + } + } + } + } + + __syncthreads(); + + if (Q_in_reg) { + const int j0 = (threadIdx.y / np) * cols_per_warp; + +#pragma unroll + for (int k0 = 0; k0 < DKQ/2; k0 += T_B_KQ::J) { + load_ldmatrix(Q_B[k0/T_B_KQ::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q); + } + } + + __syncthreads(); + + int kb0 = kb0_start; + + // Preload mask and K data for first iteration when using cp_async with multiple stages: + if constexpr (nstages > 1) { + static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline"); + constexpr bool use_cp_async = true; + constexpr bool oob_check = false; + constexpr int k_VKQ_sup = nbatch_fa; + if (ncols2 > 1 || mask_h) { + flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check> + (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01); + } + flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check> + (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup); + } + + // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. + if constexpr (ncols2 == 1) { + constexpr bool oob_check = true; + for (; kb0 < kb0_stop-1; ++kb0) { + constexpr bool last_iter = false; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check, + T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ> + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } + constexpr bool last_iter = true; + const int k_VKQ_sup = ne11 - kb0*nbatch_fa; + flash_attn_ext_f16_iter + <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check, + T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ> + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } else { + constexpr bool oob_check = false; + for (; kb0 < kb0_stop-1; ++kb0) { + constexpr bool last_iter = false; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check, + T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ> + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } + constexpr bool last_iter = true; + constexpr int k_VKQ_sup = nbatch_fa; + flash_attn_ext_f16_iter + <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check, + T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ> + (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, + KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup); + } + + // With multi-stage loading there is no __syncthreads at the end of the iter, + // there can be a race condition on shared memory access for combining/writing back results. + if constexpr (nstages > 1 && nwarps*cols_per_warp > nbatch_fa) { + __syncthreads(); + } + + // Finally, sum up partial KQ rowsums. + { +#if defined(TURING_MMA_AVAILABLE) + // The partial sums are spread across 8/4 threads. + constexpr int offset_first = cols_per_warp == 8 ? 16 : 2; + constexpr int offset_last = cols_per_warp == 8 ? 4 : 1; +#elif defined(AMD_WMMA_AVAILABLE) + // The partial sums are spread across 2 threads. + constexpr int offset_first = 16; + constexpr int offset_last = 16; +#else // Volta + // The partial sums are spread across 2 threads. + constexpr int offset_first = 2; + constexpr int offset_last = 2; +#endif // defined(TURING_MMA_AVAILABLE) +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { +#pragma unroll + for (int offset = offset_first; offset >= offset_last; offset >>= 1) { + KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); + } + } + } + + // If attention sinks are used, potentially re-scale if KQ_max is small. + // Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum + // so it's being done unconditionally for every thread. + if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) { + float KQ_max_scale[cols_per_thread]; +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col); + const float sink = sinks_f[jc % ncols2]; + + const float KQ_max_new = fmaxf(KQ_max[col], sink); + const float KQ_max_diff = KQ_max[col] - KQ_max_new; + KQ_max_scale[col] = expf(KQ_max_diff); + KQ_max[col] = KQ_max_new; + + *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD; + + const float KQ_max_add = expf(sink - KQ_max_new); + KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add; + } + +#if defined(TURING_MMA_AVAILABLE) + if constexpr (cols_per_warp == 8) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]); +#pragma unroll + for (int i = 0; i < DV/T_C_VKQ::I; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { +#pragma unroll + for (int col = 0; col < cols_per_thread; ++col) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) { + VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } +#else // Volta + const int col = (threadIdx.x / 2) % 2; + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } +#endif // defined(TURING_MMA_AVAILABLE) + } + + // Combine VKQ accumulator values if np > 1. + // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM. + // So also write VKQ accumulators to shared memory in column-major format if np == 1. + + constexpr int tile_stride = nbatch_combine + 4; + static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine"); + + if constexpr (cols_per_warp == 8) { + const int jc_cwmo = (threadIdx.x % (2*T_C_VKQ::J)) / T_C_VKQ::J; // jc combine write meta offset + const int jc_cwm = threadIdx.y*(2*T_C_VKQ::J) + 2*T_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta + const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum + + if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*T_C_VKQ::J) { + // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale. + ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; + } + + __syncthreads(); + + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && threadIdx.x < T_B_KQ::I) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + if (is_fixup && threadIdx.x < T_B_KQ::I) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + } + } else { + // jc_cwm = jc combine write meta + // KQ_cmr = KQ combine max rowsum + // Use the 16 bytes of padding in each Q column to store the meta data: KQ max, KQ rowsum, KQ max scale. +#if defined(TURING_MMA_AVAILABLE) + const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4); + const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); + const bool thread_should_write = threadIdx.x % 4 < cols_per_thread; +#elif defined(AMD_WMMA_AVAILABLE) + const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0); + const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]); + const bool thread_should_write = threadIdx.x / 16 < cols_per_thread; +#else // Volta + const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2); + const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]); + const bool thread_should_write = T_C_KQ::J == 8 || T_C_KQ::get_j(threadIdx.x & 2) < 8; +#endif // defined(TURING_MMA_AVAILABLE) + + if (((!needs_fixup && !is_fixup) || np > 1) && thread_should_write) { + ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr; + } + + __syncthreads(); + + if (np == 1) { + // No combination is needed, the meta data can be directly written from registers to VRAM. + if (needs_fixup && thread_should_write) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + if (is_fixup && thread_should_write) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[jc_cwm] = KQ_cmr; + } + } + } + + if (np > 1 && threadIdx.y % np == 0) { + // Combine the meta data for parallel warps via shared memory. + // Warps with threadIdx.y % np != 0 must NOT return early. + // All threads must return simultaneously to avoid race conditions with work on the next tile. + + constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; + + const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); + float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2; + float2 meta[nmeta]; +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2]; + } + + float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. +#pragma unroll + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x); + } +#pragma unroll + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset < WARP_SIZE) { + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); + } + } + + float KQ_cms[nmeta]; // KQ combine max scale per warp. +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn); + } + + float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps. +#pragma unroll + for (int imeta = 1; imeta < nmeta; ++imeta) { + KQ_crs += KQ_cms[imeta]*meta[imeta].y; + } +#pragma unroll + for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { + if (offset < WARP_SIZE) { + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); + } + } + + __syncthreads(); + + // Write back combined meta data: +#pragma unroll + for (int imeta = 0; imeta < nmeta; ++imeta) { + if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { + // Combined KQ max scale + rowsum. + meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); + } + } + + // Combined KQ max + rowsum. + static_assert(cols_per_warp <= WARP_SIZE); + if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; + dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); + } + } else if (np > 1) { + // Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch. + // Therefore, all other warps also need to execute a __syncthreads(). + // Otherwise the points at which warps synchronize with each other would become misaligned. + __syncthreads(); + } + +#pragma unroll + for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { + if constexpr (cols_per_warp == 8) { + const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data +#pragma unroll + for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) { + const T_B_KQ B = get_transposed(VKQ_C[(k00 + k1)/T_B_KQ::J]); // Conversion of C to B matrix puts it in column-major format. + +#pragma unroll + for (int l = 0; l < T_B_KQ::ne; ++l) { + const int k = k1 + T_B_KQ::get_j(l); + + tile_Q[jc_cwd*tile_stride + k] = B.x[l]; + } + } + } else { + const int j0 = threadIdx.y*cols_per_warp; +#pragma unroll + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + const int j = j0 + T_C_VKQ::get_i(l); + const int k = k1 + T_C_VKQ::get_j(l); + + tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l]; + } + } + } + + __syncthreads(); + + if (np == 1 || threadIdx.y % np == 0) { + // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums. + // The values after that are for the partial results of the individual blocks. + float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2)); + +#pragma unroll + for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { + const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k); + const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k); + const int stride_jc = WARP_SIZE / stride_k; + + if (k0_start == k0_stop) { + continue; + } + +#pragma unroll + for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { + const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + + if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { + break; + } + + const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp; + + const int j_dst = jc_dst / ncols2; + const int c_dst = jc_dst % ncols2; + + if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) { + continue; + } + + const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine; +#pragma unroll + for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { + const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + + float2 dstk_val = make_float2(0.0f, 0.0f); +#pragma unroll + for (int ip = 0; ip < np; ++ip) { + const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * tile_stride + 0]; + const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]); + dstk_val.x += dstk_val_add.x*KQ_crs; + dstk_val.y += dstk_val_add.y*KQ_crs; + } + + if (!needs_fixup && !is_fixup) { + const float KQ_rowsum_j = meta_j[1]; + dstk_val.x /= KQ_rowsum_j; + dstk_val.y /= KQ_rowsum_j; + } + + if (is_fixup) { + dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val; + } else { + dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val; + } + } + } + } + } + if (np > 1) { + __syncthreads(); + } + } +#else + GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup, + scale, slope, logit_softcap, ne01, ne02, gqa_ratio, + stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, + jt, kb0_start, kb0_stop); + NO_DEVICE_CODE; +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +} + +template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view> +__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2)) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))) + + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { + NO_DEVICE_CODE; + return; + } +#ifdef VOLTA_MMA_AVAILABLE + if (ncols1*ncols2 < 32) { + NO_DEVICE_CODE; + return; + } +#endif // VOLTA_MMA_AVAILABLE + +#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING + if (ncols1*ncols2 > 32) { + NO_DEVICE_CODE; + return; + } +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING + +#if defined(AMD_WMMA_AVAILABLE) + if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) { + NO_DEVICE_CODE; + return; + } +#endif // defined(AMD_WMMA_AVAILABLE) + + constexpr int ncols = ncols1 * ncols2; + constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); + constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols); + constexpr int nwarps = nthreads / WARP_SIZE; + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + + const int stride_Q1 = nb01 / sizeof(float2); + const int stride_Q2 = nb02 / sizeof(float2); + const int stride_K = nb11 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); + + const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2); + + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; + const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2; + + // kbc == k block continuous, current index in continuous ijk space. + int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + + // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. + // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). + // In the most general case >2 seams can fall into the same tile. + + // kb0 == k start index when in the output tile. + int kb0_start = kbc % iter_k; + int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); + + while (kbc < kbc_stop && kb0_stop == iter_k) { + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12); + const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); + const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; + + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. + + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV); + const half * mask_h = ncols2 == 1 && !mask ? nullptr : + (const half *) (mask + nb33*(sequence % ne33)); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2); + + const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV); + const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr; + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f; + + if (KV_max) { + kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); + } + constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. + if (kb0_start == 0) { + constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. + flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup> + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); + } else { + constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile. + flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup> + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); + } + + kbc += iter_k; + kbc -= kbc % iter_k; + + kb0_start = 0; + kb0_stop = min(iter_k, kbc_stop - kbc); + } + + if (kbc >= kbc_stop) { + return; + } + + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index. + const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12); + const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); + const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; + + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. + + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV); + const half * mask_h = ncols2 == 1 && !mask ? nullptr : + (const half *) (mask + nb33*(sequence % ne33)); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2); + + const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV); + const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr; + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f; + + if (KV_max) { + kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); + } + + constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. + constexpr bool needs_fixup = false; + flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup> + (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, + ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))) +} + +template <int DKQ, int DV, int ncols1, int ncols2> +void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + + constexpr int ncols = ncols1 * ncols2; + + const int nthreads = ggml_cuda_fattn_mma_get_nthreads (DKQ, DV, ncols, cc); + const int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols, cc); + const int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols, cc); + const int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols, cc); + const int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols, cc); + const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc); + const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc); + + const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc)); + const int nwarps = nthreads / WARP_SIZE; + + constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu + + const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); + const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2); + const size_t nbytes_shared_mask = ncols1 * (nbatch_fa/2 + 4) * sizeof(half2); + const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2); + + const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage; + + const size_t nbytes_shared_total = std::max(nbytes_shared_combine, Q_in_reg ? + std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) : + nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask); + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + +#if defined(GGML_USE_HIP) + using fattn_kernel_ptr_t = const void*; +#else + using fattn_kernel_ptr_t = fattn_kernel_t; +#endif // defined(GGML_USE_HIP) + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>; + +#if !defined(GGML_USE_MUSA) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + shared_memory_limit_raised[id] = true; + } +#endif // !defined(GGML_USE_MUSA) + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>; + +#if !defined(GGML_USE_MUSA) + static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; + if (!shared_memory_limit_raised[id]) { + CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + shared_memory_limit_raised[id] = true; + } +#endif // !defined(GGML_USE_MUSA) + } + + launch_fattn<DV, ncols1, ncols2> + (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true); +} + + +#define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2) \ + template void ggml_cuda_flash_attn_ext_mma_f16_case \ + <DKQ, DV, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols) \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 1, 1); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 2, 2); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 4, 4); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 8, 8); \ + extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/16, 16); \ + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 8) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 8) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 16) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 16) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 32) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 32) + +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64) +DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) + +// The number of viable configurations for Deepseek is very limited: +extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); + +// For GLM 4.7 Flash +extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32); diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cu b/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cu new file mode 100644 index 0000000..3fcb09b --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cu @@ -0,0 +1,49 @@ +#include "common.cuh" +#include "fattn-tile.cuh" +#include "fattn-wmma-f16.cuh" + +void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + switch (K->ne[0]) { + case 40: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 40, 40>(ctx, dst); + } break; + case 64: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst); + } break; + case 72: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 72, 72>(ctx, dst); + } break; + case 80: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst); + } break; + case 96: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst); + } break; + case 112: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst); + } break; + case 128: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst); + } break; + case 256: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); + } break; + case 576: { + GGML_ASSERT(V->ne[0] == 512); + ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst); + } break; + default: { + GGML_ABORT("Unsupported head size"); + } break; + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cuh b/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cuh new file mode 100644 index 0000000..b6db582 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cuh @@ -0,0 +1,1256 @@ +#include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-wmma-f16.cuh" + +// nbatch_fa == number of KQ rows to process per iteration +// nbatch_K == number of K columns to load in parallel for KQ calculation + +// TODO optimize kernel parameters for FP16 NVIDIA (P100) +// TODO optimize kernel parameters for head sizes 40, 72, 80, 96, 112 + +// The ROCm compiler cannot handle templating in __launch_bounds__. +// As a workaround, define a macro to package the kernel parameters as uint32_t: +#define GGML_CUDA_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \ + if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \ + static_assert((nthreads) <= 512, "bad nthreads"); \ + static_assert((occupancy) <= 8, "bad occupancy"); \ + static_assert((nbatch_fa) <= 256, "bad nbatch_fa"); \ + static_assert((nbatch_K) <= 256, "bad nbatch_K"); \ + return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23); \ + } \ + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp16(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 64, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 64, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 64, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 64, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 64, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 64, 72) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp32(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 3, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) + + return 0; +} + +static __host__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) { + if (GGML_CUDA_CC_IS_AMD(cc)) { + if (GGML_CUDA_CC_IS_RDNA(cc)) { + return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols); + } + return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols); + } + if (fast_fp16_available(cc)) { + return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols); + } + return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols); +} + +static constexpr __device__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) { +#ifdef GGML_USE_HIP +#ifdef RDNA + return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols); +#else + return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols); +#endif // RDNA +#else +#ifdef FAST_FP16_AVAILABLE + return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols); +#else + return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols); +#endif // FAST_FP16_AVAILABLE +#endif // GGML_USE_HIP +} + +static __host__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1); +} + +// TODO: deduplicate with mma-f16 +template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check> +static __device__ __forceinline__ void flash_attn_tile_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] __device__ (const int n) { + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j); + const int j0_stop = ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne; + + const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}}; + ggml_cuda_memcpy_1<cpy_nb>( + tile_KV + i*(J/2 + J_padding) + j, + !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + } + } + } + }; + // 1: max 64*16=512 bytes, 512 half + // 2: max 32*16=512 bytes, 256 half + // 3: max 16*16=256 bytes, 128 half + // 4: max 8*16=128 bytes, 64 half + // 5: max 4*16= 64 bytes, 32 half + // 6: max 2*16= 32 bytes, 16 half + // 7: max 1*16= 16 bytes, 8 half + static_assert(J % 8 == 0, "bad J"); + static_assert((J/2) % cpy_ne == 0, "bad J"); + ggml_cuda_unroll<7>{}(load); +} + +template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check> +static __device__ __forceinline__ void flash_attn_tile_load_tile( + const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] __device__ (const int n) { + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j); + const int j0_stop = (J/cpy_ne) - (J/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2); + + const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}}; + __align__(16) half2 tmp_h2[cpy_ne/2]; + ggml_cuda_memcpy_1<sizeof(tmp_h2)>( + tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + + __align__(16) float2 tmp_f2[cpy_ne/2]; +#pragma unroll + for (int l = 0; l < cpy_ne/2; ++l) { + tmp_f2[l] = __half22float2(tmp_h2[l]); + } + ggml_cuda_memcpy_1<sizeof(tmp_f2)>(tile_KV + i*(J + J_padding) + 2*j, tmp_f2); + } + } + } + }; + // 1: max 32*16=512 bytes, 128 float + // 2: max 16*16=256 bytes, 64 float + // 3: max 8*16=128 bytes, 32 float + // 4: max 4*16= 64 bytes, 16 float + // 5: max 2*16= 32 bytes, 8 float + static_assert(J % 8 == 0, "bad J"); + static_assert(J % cpy_ne == 0, "bad J"); + ggml_cuda_unroll<5>{}(load); +} + +// Function that performs a single iteration in for the KQ matrix multiplication: +template <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int nbatch_fa, int nbatch_K, + bool use_logit_softcap, bool oob_check, typename T_vec_dot> +static __device__ __forceinline__ void flash_attn_tile_iter_KQ( + T_vec_dot * const Q_tmp, + const half2 * const __restrict__ K_h2, + T_vec_dot * const KV_tmp, + const int stride_K2, + const int k_VKQ_0, + const int k_VKQ_sup, + const int k_KQ_0, + float * KQ_acc) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int ncols = ncols1*ncols2; + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column + + flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check> + (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE + static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K"); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) { + __align__(16) half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + __align__(16) half2 Q_k[cpw][cpy_ne]; +#else + static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K"); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) { + __align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + __align__(16) float Q_k[cpw][cpy_ne]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { + const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x; + +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]); +#else + ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K + cpy_ne) + k_KQ_1]); +#endif // FAST_FP16_AVAILABLE + } +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y / np)*cpw; + +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]); +#else + ggml_cuda_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc* DKQ + k_KQ_0 + k_KQ_1]); +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { +#pragma unroll + for (int k = 0; k < cpy_ne; ++k) { + ggml_cuda_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]); + } + } + } + } + + if (k_KQ_0 + nbatch_K < DKQ) { + __syncthreads(); // Sync not needed on last iteration. + } +} + +// Function that performs a single iteration of the main loop over up to nbatch_fa tokens. +template <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int DV, int nbatch_fa, int nbatch_K, + bool use_logit_softcap, bool oob_check, typename T_vec_dot, typename T_KQ, typename T_acc> +static __device__ __forceinline__ void flash_attn_tile_iter( + T_vec_dot * const Q_tmp, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half * const __restrict__ mask, + const uint3 ne01, + const float logit_softcap, + const float slope, + T_KQ * const KQ, + T_vec_dot * const KV_tmp, + const int stride_K2, + const int stride_V2, + const int stride_mask, + float * const KQ_max, + float * const KQ_sum, + T_acc * const VKQ, + const int k_VKQ_0, + const int k_VKQ_max, + const int col_Q_0) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int ncols = ncols1*ncols2; + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column + + constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size. + + // KQ_cs == KQ chunk size, number of KQ values in j direction to store as one contiguous chunk in memory. + // KQ is originally 2D but uses a Z-shaped 3D memory pattern like KQ[ncols/KQ_cs][DVp][KQ_cs]. +#ifdef FAST_FP16_AVAILABLE + constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne; +#else + constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne; +#endif // FAST_FP16_AVAILABLE + static_assert(cpw % KQ_cs == 0, "bad KQ_cs"); + const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data + + float KQ_max_new[cpw]; +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + KQ_max_new[jc0] = KQ_max[jc0]; + } + + float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication. + + // KQ = K @ Q matrix multiplication: + constexpr int nbatch_K_last = DKQ % nbatch_K; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) { + flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>( + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + } + if (nbatch_K_last > 0) { + constexpr int k_KQ_0 = DKQ - nbatch_K_last; + flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K_last, use_logit_softcap, oob_check>( + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + } + + // Apply logit softcap + mask, update KQ_max: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int j = fastmodulo(col_Q_0 + (jc0 + (threadIdx.y / np)*cpw)/ncols2, ne01); + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { + const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x; + +#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE) + // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation. + // Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again. + KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0] *= 4.0f; +#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE) + + if (use_logit_softcap) { + KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + } + + if (!oob_check || i_KQ < k_VKQ_sup) { + KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ? + slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f; + + KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] + FATTN_KQ_MAX_OFFSET); + } + } + + KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]); + } + + if constexpr (np == 1) { + __syncthreads(); + } else { + static_assert(cpw == 1, "bad cpw"); + __shared__ float KQ_max_new_shared[nwarps]; + if (threadIdx.x == 0) { + KQ_max_new_shared[threadIdx.y] = KQ_max_new[0]; + } + __syncthreads(); + KQ_max_new[0] = KQ_max_new_shared[(threadIdx.y & ~(np-1)) + threadIdx.x % np]; + KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]); + } + + // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) { +#ifdef FAST_FP16_AVAILABLE + __align__(16) half tmp[nbatch_fa/(np*warp_size)][KQ_cs]; +#else + __align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int jc1 = 0; jc1 < KQ_cs; ++jc1) { + const int jc = jc0 + jc1; + + const float KQ_max_scale = expf(KQ_max[jc] - KQ_max_new[jc]); + KQ_max[jc] = KQ_max_new[jc]; + + float KQ_sum_add = 0.0f; +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { + const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast<uint32_t>(k_VKQ_sup) ? + expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f; + KQ_sum_add += val; + tmp[i0/(np*warp_size)][jc1] = val; + } + KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale; + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { + const int i = i0 + (threadIdx.y % np)*warp_size + threadIdx.x; + + ggml_cuda_memcpy_1<sizeof(tmp[0])>( + KQ + (jc0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs))*(nbatch_fa*KQ_cs) + i*KQ_cs, + tmp[i0/(np*warp_size)]); + } + } + + // VKQ = V @ KQ matrix multiplication: + static_assert(DV <= DKQ, "bad DV"); + static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K"); + constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K. + static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V"); + static_assert(nbatch_V % np == 0, "bad nbatch_V"); +#pragma unroll + for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) { + flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check> + (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE +#pragma unroll + for (int k1 = 0; k1 < nbatch_V; k1 += np) { + __align__(16) half2 V_k[(DVp/2)/warp_size]; + __align__(16) half2 KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[(k1 + threadIdx.y % np)*(DV/2) + i0 + threadIdx.x*cpy_ne_D]); + } +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { + const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs); + + __align__(16) half tmp[KQ_cs]; + ggml_cuda_memcpy_1<KQ_cs*sizeof(half)>( + &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs); +#pragma unroll + for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) { + KQ_k[jc_VKQ_0+jc_VKQ_1] = __half2half2(tmp[jc_VKQ_1]); + } + } + +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) { + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size] += V_k[i0/warp_size]*KQ_k[jc_VKQ_0]; + } + } + } +#else +#pragma unroll + for (int k1 = 0; k1 < nbatch_V; k1 += np) { + __align__(16) float2 V_k[(DVp/2)/warp_size]; + __align__(16) float KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + threadIdx.y % np)*DV + i0 + threadIdx.x*cpy_ne_D]); + } +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { + const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs); + + ggml_cuda_memcpy_1<KQ_cs*sizeof(float)>( + &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs); + } + +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) { + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[jc_VKQ_0]; + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[jc_VKQ_0]; + } + } + } +#endif // FAST_FP16_AVAILABLE + + __syncthreads(); + } +} + +template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap> // D == head size +__launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2)) +static __global__ void flash_attn_tile( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { +#ifdef FLASH_ATTN_AVAILABLE + + // Skip unused kernel variants for faster compilation: + + if ( +#ifdef GGML_USE_WMMA_FATTN + (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) || +#endif // GGML_USE_WMMA_FATTN + (use_logit_softcap && !(DV == 128 || DV == 256)) + ) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; + return; + } + + static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined"); + + constexpr int ncols = ncols1*ncols2; + constexpr int warp_size = 32; + constexpr int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size; + constexpr int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2); + constexpr int nbatch_K = ggml_cuda_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2); + + // In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int col_Q_0 = blockIdx.x * ncols1; // Index of the first Q column for this CUDA block to work on. + + const int sequence = blockIdx.z / (ne02/ncols2); + const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2) + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape + + const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr; + + const int stride_K2 = nb11 / sizeof(half2); + const int stride_V2 = nb21 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp. + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column. + static_assert(cpw == 1 || np == 1, "bad cpw / np"); + static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0"); + + constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size. + constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size. + + // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel. + // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11. + // KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV). + // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications. + // VKQ == Accumulators in registers for the final VKQ result. +#ifdef FAST_FP16_AVAILABLE + __shared__ half2 Q_tmp[ncols * DKQ/2]; + __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV]; + __shared__ half KQ[ncols * nbatch_fa]; + __align__(16) half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; +#else + __shared__ float Q_tmp[ncols * DKQ]; + __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV]; + __shared__ float KQ[ncols * nbatch_fa]; + __align__(16) float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; +#endif // FAST_FP16_AVAILABLE + + float KQ_max[cpw]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + KQ_max[j0/nwarps] = -FLT_MAX/2.0f; + } + float KQ_sum[cpw] = {0.0f}; + + // Load Q data, convert to FP16 if fast: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y / np)*cpw; + + const int j = jc / ncols2; + const int c = jc % ncols2; + + constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size; + +#pragma unroll + for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) { + if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) { + __align__(16) float tmp_f[cpy_ne_D] = {0.0f}; + ggml_cuda_memcpy_1<sizeof(tmp_f)> + (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float)) + + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); + +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp_f[i1] *= scale; + } + +#ifdef FAST_FP16_AVAILABLE + __align__(16) half2 tmp_h2[cpy_ne_D/2]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { + tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); +#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE) + // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation. + // Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again. + tmp_h2[i1/2] *= make_half2(0.25f, 0.25f); +#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE) + } + ggml_cuda_memcpy_1<sizeof(tmp_h2)>( + &Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)], + tmp_h2); +#else + ggml_cuda_memcpy_1<sizeof(tmp_f)>( + &Q_tmp[jc* DKQ + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x* cpy_ne_D], + tmp_f); +#endif // FAST_FP16_AVAILABLE + } + } + } + + __syncthreads(); + + // Main loop over KV cache: + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + if (ncols2 == 1) { + // Branch with out-of-bounds checks. + int k_VKQ_0 = blockIdx.y*nbatch_fa; + while (k_VKQ_0 < k_VKQ_max - nbatch_fa) { + constexpr bool oob_check = false; + flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check> + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); + k_VKQ_0 += gridDim.y*nbatch_fa; + } + if (k_VKQ_0 < k_VKQ_max) { + constexpr bool oob_check = true; + flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check> + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); + } + } else { + // Branch without out-of-bounds checks. + for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) { + constexpr bool oob_check = false; + flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check> + (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0); + } + } + +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + KQ_sum[jc0] = warp_reduce_sum<warp_size>(KQ_sum[jc0]); + } + + if constexpr (np > 1) { + static_assert(cpw == 1, "bad cpw"); + static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small"); + +#ifdef FAST_FP16_AVAILABLE + half2 * VKQ_combine = (half2 *) KV_tmp; +#else + float * VKQ_combine = (float *) KV_tmp; +#endif // FAST_FP16_AVAILABLE + float * KQ_sum_combine = (float *) Q_tmp; + + if (threadIdx.y % np != 0) { +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1<cpy_ne_D*4>(&VKQ_combine[threadIdx.y*(DVp/2) + i0 + threadIdx.x*cpy_ne_D], &VKQ[i0/warp_size]); + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1<cpy_ne_D*4>( + &VKQ_combine[threadIdx.y*DVp + i0 + threadIdx.x*cpy_ne_D], ((const float *) VKQ) + i0/warp_size); + } +#endif // FAST_FP16_AVAILABLE + + if (threadIdx.x == 0) { + KQ_sum_combine[threadIdx.y] = KQ_sum[0]; + } + + return; + } + + __syncthreads(); + +#pragma unroll + for (int ip = 1; ip < np; ++ip) { +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + __align__(16) half2 tmp[cpy_ne_D]; + ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]); +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + VKQ[i0/warp_size + i1] += tmp[i1]; + } + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + __align__(16) float tmp[cpy_ne_D]; + ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]); +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + ((float *)VKQ)[i0/warp_size + i1] += tmp[i1]; + } + } +#endif // FAST_FP16_AVAILABLE + + KQ_sum[0] += KQ_sum_combine[threadIdx.y + ip]; + } + } + + // Attention sink: adjust KQ max and sum only for the first of all parallel blocks: + if (sinks && blockIdx.y == 0) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y/np)*cpw; + const float sink = ((const float *) sinks)[head0 + jc % ncols2]; + + float KQ_max_new_j = fmaxf(KQ_max[jc0], sink); + const float KQ_max_scale = expf(KQ_max[jc0] - KQ_max_new_j); + KQ_max[jc0] = KQ_max_new_j; + + const float val = expf(sink - KQ_max[jc0]); + KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale; + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + } + + // Write back results: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y/np)*cpw; + + const int j = jc / ncols2; + const int c = jc % ncols2; + + if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z)) { + return; + } + + const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f; + + const int j_dst_unrolled = ((sequence*int(ne01.z) + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; + +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + __align__(16) float2 tmp[cpy_ne_D]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]); + tmp[i1].x *= scale; + tmp[i1].y *= scale; + } + if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) { + ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); + } + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) { +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) { + VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale; + VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale; + } + ggml_cuda_memcpy_1<cpy_ne_D*4>( + &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D], + &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]); + } + } +#endif // FAST_FP16_AVAILABLE + + if (gridDim.y != 1 && threadIdx.x == 0) { + dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]); + } + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; +#endif // FLASH_ATTN_AVAILABLE +} + +template <int DKQ, int DV, int ncols2, bool use_logit_softcap> +static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int warp_size = 32; + + constexpr size_t nbytes_shared = 0; + +#ifdef GGML_USE_HIP + if constexpr (DV <= 128) { + if (Q->ne[1] > 32/ncols2) { + constexpr int cols_per_block = 64; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>; + launch_fattn<DV, cols_per_block/ncols2, ncols2> + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } +#endif // GGML_USE_HIP + +#ifndef GGML_USE_HIP + if constexpr (DV <= 256) +#endif // GGML_USE_HIP + { + if (Q->ne[1] > 16/ncols2) { + constexpr int cols_per_block = 32; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>; + launch_fattn<DV, cols_per_block/ncols2, ncols2> + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>; + launch_fattn<DV, cols_per_block/ncols2, ncols2> + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + + if constexpr (ncols2 <= 8) { + if (Q->ne[1] > 4/ncols2) { + constexpr int cols_per_block = 8; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>; + launch_fattn<DV, cols_per_block/ncols2, ncols2> + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if constexpr (ncols2 <= 4) { + if (Q->ne[1] > 2/ncols2) { + constexpr int cols_per_block = 4; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>; + launch_fattn<DV, cols_per_block/ncols2, ncols2> + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if constexpr (ncols2 <= 2) { + constexpr int cols_per_block = 2; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>; + launch_fattn<DV, cols_per_block/ncols2, ncols2> + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + + GGML_ABORT("fatal error"); +} + +template <int DKQ, int DV, bool use_logit_softcap> +static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * mask = dst->src[3]; + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + + const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc); + const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX; + const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; + + if constexpr (DV == 512) { + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst); + return; + } + if (use_gqa_opt && gqa_ratio % 4 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst); + return; + } + } + + if constexpr (DV <= 256) { + if (use_gqa_opt && gqa_ratio % 8 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 4 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst); + return; + } + + launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst); + return; + } + GGML_ABORT("fatal error"); +} + +template <int DKQ, int DV> +void ggml_cuda_flash_attn_ext_tile_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst); + } +} + +void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +#define DECL_FATTN_TILE_CASE(DKQ, DV) \ + template void ggml_cuda_flash_attn_ext_tile_case \ + <DKQ, DV>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +extern DECL_FATTN_TILE_CASE( 40, 40); +extern DECL_FATTN_TILE_CASE( 64, 64); +extern DECL_FATTN_TILE_CASE( 72, 72); +extern DECL_FATTN_TILE_CASE( 80, 80); +extern DECL_FATTN_TILE_CASE( 96, 96); +extern DECL_FATTN_TILE_CASE(112, 112); +extern DECL_FATTN_TILE_CASE(128, 128); +extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh b/llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh new file mode 100644 index 0000000..3f4a78c --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh @@ -0,0 +1,586 @@ +#include "common.cuh" +#include "fattn-common.cuh" + +static int ggml_cuda_fattn_vec_get_nthreads_host(const int cc) { + return 128; + GGML_UNUSED(cc); +} + +static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() { + return 128; +} + +// Currenlty llvm with the amdgcn target does not support unrolling loops +// that contain a break that can not be resolved at compile time. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size +__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1) +static __global__ void flash_attn_ext_vec( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { +#ifdef FLASH_ATTN_AVAILABLE + + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; + return; + } + + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + +#ifdef GGML_USE_HIP +#ifdef RDNA + constexpr int nthreads_KQ_q = 2; +#else + constexpr int nthreads_KQ_q = 4; +#endif // RDNA + constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32); +#else + constexpr int nthreads_KQ_q = (D/4 < 32 ? D/4 : 32); + constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32); +#endif // GGML_USE_HIP + + constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device(); + constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q; + constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q; + + static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K"); + static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V"); + + constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4; + constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; + + constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>(); + constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; +#ifdef V_DOT2_F32_F16_AVAILABLE + constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>(); +#else + constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>(); +#endif // V_DOT2_F32_F16_AVAILABLE + + const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + Q += nb03*sequence + nb02* head + nb01*ic0; + K += nb13*sequence + nb12*(head / gqa_ratio); + V += nb23*sequence + nb22*(head / gqa_ratio); + + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); + + const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); + + static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); + constexpr int nwarps = nthreads / WARP_SIZE; + const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; + __builtin_assume(tid < nthreads); + + constexpr int ne_KQ = ncols*D; + constexpr int ne_combine = nwarps*V_cols_per_iter*D; +#ifdef V_DOT2_F32_F16_AVAILABLE + half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; + __shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; +#else + float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; + __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine]; +#endif // V_DOT2_F32_F16_AVAILABLE + + float KQ_max[ncols]; + float KQ_sum[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_max[j] = -FLT_MAX/2.0f; + KQ_sum[j] = 0.0f; + } + + // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: +#ifdef V_DOT2_F32_F16_AVAILABLE + half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely. +#else + __align__(16) float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized. +#endif // V_DOT2_F32_F16_AVAILABLE + int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; + float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; + if constexpr (Q_q8_1) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + + // Reuse KQ as temporary storage for converting Q to q8_1: + int * tmp_q_i32 = (int *) &KQ[j*D]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + + // Set memory to zero if out of bounds: + if (ncols > 1 && ic0 + j >= int(ne01.z)) { +#pragma unroll + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) { + tmp_q_i32[i] = 0; + } + } + if (threadIdx.x < D/QK8_1) { + tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f); + } + } else { + const float * Q_f = (const float *) (Q + j*nb01); + constexpr int nthreads_quantize = D/sizeof(int) < WARP_SIZE ? D/sizeof(int) : WARP_SIZE; +#pragma unroll + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) { + quantize_q8_1_to_shared<float2, nthreads_quantize> + (Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1); + } + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + int * tmp_q_i32 = (int *) &KQ[j*D]; + float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int)); + +#pragma unroll + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) { + const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ); + + Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i]; + Q_ds[j][i0/nthreads_KQ] = tmp_q_ds[i/QI8_1]; + } + } + + __syncthreads(); + } else { +#ifdef V_DOT2_F32_F16_AVAILABLE + const half2 scale_h2 = make_half2(scale, scale); +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const float2 * Q_j = (const float2 *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { + const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; + + __align__(16) float2 tmp[cpy_ne] = {{0.0f, 0.0f}}; + if (ncols == 1 || ic0 + j < int(ne01.z)) { + ggml_cuda_memcpy_1<cpy_nb>(tmp, &Q_j[i]); + ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]); + } +#pragma unroll + for (int i1 = 0; i1 < cpy_ne; ++i1) { + Q_reg[j][i0/nthreads_KQ + i1] = make_half2(tmp[i1].x, tmp[i1].y); + } + } +#pragma unroll + for (int k = 0; k < (D/2)/nthreads_KQ; ++k) { + Q_reg[j][k] *= scale_h2; + } + } +#else +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const float2 * Q_j = (const float2 *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { + const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; + if (ncols == 1 || ic0 + j < int(ne01.z)) { + ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]); + ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]); + } + } +#pragma unroll + for (int k = 0; k < (D/2)/nthreads_KQ; ++k) { + Q_reg[j][k].x *= scale; + Q_reg[j][k].y *= scale; + } + } +#endif // V_DOT2_F32_F16_AVAILABLE + } + + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + K += blockIdx.y*nthreads * nb11; + V += blockIdx.y*nthreads * nb21; + maskh += blockIdx.y*nthreads; + for (int k_VKQ_0 = blockIdx.y*nthreads; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nthreads, + // Increment pointers after each loop: + K += gridDim.y*nthreads*nb11, V += gridDim.y*nthreads*nb21, maskh += gridDim.y*nthreads) { + + // Calculate KQ tile and keep track of new maximum KQ values: + float KQ_reg[ncols]; // KQ in registers. + + float KQ_max_new[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_max_new[j] = KQ_max[j]; + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) { + const int i_KQ = threadIdx.y*WARP_SIZE + (nthreads_KQ == WARP_SIZE ? 0 : (threadIdx.x & ~(nthreads_KQ-1))) + i_KQ_0; + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]); + sum = warp_reduce_sum<nthreads_KQ>(sum); + + if (use_logit_softcap) { + sum = logit_softcap*tanhf(sum); + } + + if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) { + sum += slope*__half2float(maskh[j*ne11 + i_KQ]); + } + + KQ_max_new[j] = fmaxf(KQ_max_new[j], sum + FATTN_KQ_MAX_OFFSET); + + if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) { + KQ_reg[j] = sum; + } + } + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { +#pragma unroll + for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) { + KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE)); + } + const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]); + KQ_max[j] = KQ_max_new[j]; + + KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]); + KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j]; + KQ[j*nthreads + tid] = KQ_reg[j]; + +#ifdef V_DOT2_F32_F16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale; + VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale; + } +#endif // V_DOT2_F32_F16_AVAILABLE + } + +#ifndef GGML_USE_HIP + __syncwarp(); +#endif // GGML_USE_HIP + +#pragma unroll + for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) { + const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V); + +#ifdef V_DOT2_F32_F16_AVAILABLE + half2 KQ_k[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_k[j] = __half2half2(KQ[j*nthreads + k]); + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + half2 tmp[V_rows_per_thread/2]; + dequantize_V(V + k*nb21, tmp, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j]; + } + } + } +#else + float KQ_k[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_k[j] = KQ[j*nthreads + k]; + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + float2 tmp[V_rows_per_thread/2]; + dequantize_V(V + k*nb21, tmp, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j]; + VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j]; + } + } + } +#endif // V_DOT2_F32_F16_AVAILABLE + } + } + + if (sinks && blockIdx.y == 0) { + const float sink = ((const float *) sinks)[head]; + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + + const float kqmax_new_j = fmaxf(sink, KQ_max[j]); + const float KQ_max_scale = expf(KQ_max[j] - kqmax_new_j); + KQ_max[j] = kqmax_new_j; + + KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f); + +#ifdef V_DOT2_F32_F16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale; + VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale; + } +#endif // V_DOT2_F32_F16_AVAILABLE + } + } + + __shared__ float KQ_max_shared[ncols][WARP_SIZE]; + __shared__ float KQ_sum_shared[ncols][WARP_SIZE]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.y == 0) { + KQ_max_shared[j][threadIdx.x] = -FLT_MAX/2.0f; + KQ_sum_shared[j][threadIdx.x] = 0.0f; + } + } + + __syncthreads(); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (threadIdx.x == 0) { + KQ_max_shared[j][threadIdx.y] = KQ_max[j]; + } + } + __syncthreads(); + +#pragma unroll + for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { + if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) { + break; + } + + float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x]; + kqmax_new = warp_reduce_max(kqmax_new); + const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new); + KQ_max[j_VKQ] = kqmax_new; + +#ifdef V_DOT2_F32_F16_AVAILABLE + half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2) + + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2); + + const half2 kqmax_scale_h2 = make_half2(kqmax_scale, kqmax_scale); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2; + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2); + + ggml_cuda_memcpy_1<V_rows_per_thread*sizeof(half)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]); + } +#else + float2 * VKQ_tmp = (float2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2) + + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2); + +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j_VKQ][i_VKQ_0/nthreads_V].x *= kqmax_scale; + VKQ[j_VKQ][i_VKQ_0/nthreads_V].y *= kqmax_scale; + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2); + + ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]); + ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]); + } +#endif // V_DOT2_F32_F16_AVAILABLE + + KQ_sum[j_VKQ] *= kqmax_scale; + KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]); + if (threadIdx.x == 0) { + KQ_sum_shared[j_VKQ][threadIdx.y] = KQ_sum[j_VKQ]; + } + + __syncthreads(); + + if (nthreads <= D || tid < D) { + KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ][threadIdx.x]; + KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]); + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += nthreads) { + float dst_val = 0; +#pragma unroll + for (int w = 0; w < nwarps; ++w) { +#pragma unroll + for (int v = 0; v < V_cols_per_iter; ++v) { + dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]); + } + } + if (gridDim.y == 1) { + dst_val /= KQ_sum[j_VKQ]; + } + dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val; + } + } + + if (j_VKQ < ncols-1) { + __syncthreads(); + } + + } + + if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) { + dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; +#endif // FLASH_ATTN_AVAILABLE +} +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ + +template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> +void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc); + const int nwarps = nthreads / WARP_SIZE; + fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>; + const bool need_f16_K = type_K == GGML_TYPE_F16; + const bool need_f16_V = type_V == GGML_TYPE_F16; + constexpr size_t nbytes_shared = 0; + launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); +} + +template <int D, ggml_type type_K, ggml_type type_V> +void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (Q->ne[1] == 1) { + constexpr int cols_per_block = 1; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst); + } + return; + } + + constexpr int cols_per_block = 2; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst); + } +} + +#define DECL_FATTN_VEC_CASE(D, type_K, type_V) \ + template void ggml_cuda_flash_attn_ext_vec_case \ + <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ + +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0) + +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0) + +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu new file mode 100644 index 0000000..8694fd0 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -0,0 +1,675 @@ +// Old and deprecated WMMA FlashAttention implementation. +// It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing. +// Long-term the WMMA code should be replaced with a dedicated Volta implementation. + +#include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-wmma-f16.cuh" + +#ifdef GGML_USE_WMMA_FATTN +#if !defined(GGML_USE_HIP) +#include <mma.h> +#if defined(GGML_USE_MUSA) +namespace wmma = mtmusa::wmma; +#else // GGML_USE_MUSA +namespace wmma = nvcuda::wmma; +#endif // GGML_USE_MUSA +#elif defined(GGML_USE_HIP) +#include <rocwmma/rocwmma.hpp> +namespace wmma = rocwmma; +#endif // !defined(GGML_USE_HIP) +#endif // GGML_USE_WMMA_FATTN + +// D == head size, VKQ_stride == num VKQ rows calculated in parallel: +template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap> +__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1) +static __global__ void flash_attn_ext_f16( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { +#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) + // Skip unused kernel variants for faster compilation: + if (use_logit_softcap && !(D == 128 || D == 256)) { + NO_DEVICE_CODE; + return; + } + + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on. + + static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); + static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); + constexpr int frag_m = ncols == 8 ? 32 : 16; + constexpr int frag_n = ncols == 8 ? 8 : 16; + static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); + typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K; + typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V; + typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b; + typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ; + typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ; + + constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. + constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. + static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); + + // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: + constexpr int D_padded = D + 8; + constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; + constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); + + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0); + const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio)); + const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); + const half2 * mask2 = (const half2 *) maskh; + const float * sinksf = (const float *) sinks; + + const int stride_Q = nb01 / sizeof(float); + const int stride_KV = nb11 / sizeof(half); + + const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); + const half slopeh = __float2half(slopef); + const half2 slope2 = make_half2(slopef, slopef); + + const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap); + + frag_b Q_b[D/16][ncols/frag_n]; + + // A single buffer for temporarily holding tiles of KQ and VKQ parts: + constexpr int mem_KQ = ncols*kqs_padded*kqar; + constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; + __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; + float * KQ_f = (float *) KQ; + half2 * KQ2 = (half2 *) KQ; + + float KQ_rowsum_f[ncols/nwarps] = {0.0f}; + float KQ_max_f[ncols/nwarps]; + float KQ_max_scale_f[ncols/nwarps] = {0.0f}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_f[j] = -FLT_MAX/2.0f; + } + + half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + half2 KQ_max_h2[ncols/nwarps]; + half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; + +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); + } + + __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. + half2 * VKQ2 = (half2 *) VKQ; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + const int i = i0 + threadIdx.x; + if (i0 + warp_size > D/2 && i >= D/2) { + break; + } + VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); + } + } + + // Convert Q to half and apply scale, temporarily store in KQ: +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; +#pragma unroll + for (int i0 = 0; i0 < D; i0 += warp_size) { + const int i = i0 + threadIdx.x; + if (i0 + warp_size > D && i >= D) { + break; + } + KQ[j*D_padded + i] = ic0 + j < int(ne01.z) ? Q_f[j*stride_Q + i] * scale : 0.0f; + } + } + + __syncthreads(); + + // Load Q into tensor core fragments/registers since it will be used frequently: +#pragma unroll + for (int i0 = 0; i0 < D; i0 += 16) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + } + } + + __syncthreads(); + + // Iterate over ne11 == previous tokens: + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) { + // Calculate tile of KQ: +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { + frag_c_KQ KQ_c[ncols/frag_n]; +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f)); + } +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { + frag_a_K K_a; + wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); + } + } +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major); + } + } + + __syncthreads(); + + // Calculate softmax for each KQ column using the current max. value. + // The divisor is stored in KQ_rowsum and will be applied at the end. +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (std::is_same<KQ_acc_t, float>::value) { + float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { + const int k = k0 + threadIdx.x; + + KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k]; + + if (use_logit_softcap) { + KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]); + } + } + + float KQ_max_new = KQ_max_f[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { + const int k = k0 + threadIdx.x; + + KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? + __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; + KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size] + FATTN_KQ_MAX_OFFSET); + } + KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new); + + const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; + KQ_max_scale_f[j0/nwarps] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_max_scale_f[j0/nwarps] = 0.0f; + } + KQ_max_f[j0/nwarps] = KQ_max_new; + + float KQ_rowsum_add = 0.0f; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) { + const int k = k0 + threadIdx.x; + + const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps]; + KQ_f_tmp[k0/warp_size] = expf(diff); + if (diff <= SOFTMAX_FTZ_THRESHOLD) { + KQ_f_tmp[k0/warp_size] = 0.0f; + } + KQ_rowsum_add += KQ_f_tmp[k0/warp_size]; + KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size]; + } + KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; + } else { + half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k]; + + if (use_logit_softcap) { + // There is no dedicated tangens hyperbolicus function for half2. + KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f)); + KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f)) + /(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f)); + + KQ2_tmp[k0/warp_size] *= logit_softcap_2; + } + } + + half2 KQ_max_new = KQ_max_h2[j0/nwarps]; +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { + const int k = k0 + threadIdx.x; + + KQ2_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); + KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]); + } + KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); + const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; + KQ_max_scale_h2[j0/nwarps] = h2exp(diff); + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; + KQ_max_h2[j0/nwarps] = KQ_max_new; + + half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) { + const int k = k0 + threadIdx.x; + + const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps]; + KQ2_tmp[k0/warp_size] = h2exp(diff); + const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); + *((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask; + KQ_rowsum_add += KQ2_tmp[k0/warp_size]; + KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size]; + } + KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add); + + // Scale previous KQ_rowsum to account for a potential increase in KQ_max: + KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; + } + } + + __syncthreads(); + + frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + wmma::load_matrix_sync( + KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], + KQ + j0*(kqar*kqs_padded) + k, + kqar*kqs_padded); + } + } + + frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f)); + } + +#pragma unroll + for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { + const int k = k0 + (threadIdx.y % VKQ_ratio)*16; + + frag_a_V v_a; + wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); +#pragma unroll + for (int j = 0; j < ncols/frag_n; ++j) { + wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); + } + } + } + + __syncthreads(); + + const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += frag_n) { + wmma::store_matrix_sync( + KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], + D_padded, wmma::mem_col_major); + } + } + + __syncthreads(); + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + half2 VKQ_scale; + if (std::is_same<KQ_acc_t, float>::value) { + VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); + } else { + VKQ_scale = KQ_max_scale_h2[j0/nwarps]; + } + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + const int i = i0 + threadIdx.x; + if (i0 + warp_size > D/2 && i >= D/2) { + break; + } + + half2 VKQ_add = make_half2(0.0f, 0.0f); +#pragma unroll + for (int l = 0; l < VKQ_ratio; ++l) { + VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; + } + VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; + } + } + + __syncthreads(); + } + + // Apply attention sinks + if (sinksf && blockIdx.y == 0) { + const float sinkf = sinksf[head]; + const half sinkh = __float2half(sinkf); + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (std::is_same<KQ_acc_t, float>::value) { + float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf); + + const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new); + KQ_max_f[j0/nwarps] = kqmax_new; + + KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]); + + const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + const int i = i0 + threadIdx.x; + if (i0 + warp_size > D/2 && i >= D/2) break; + VKQ2[j*(D_padded/2) + i] *= scale_h2; + } + } else { + half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]); + half kqmax_new = fmaxf(kqmax_old, sinkh); + KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new); + + const half KQ_max_scale_h = hexp(kqmax_old - kqmax_new); + const half2 KQ_max_scale = __half2half2(KQ_max_scale_h); + + KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale; + const half val = hexp(sinkh - kqmax_new); + KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val); + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + const int i = i0 + threadIdx.x; + if (i0 + warp_size > D/2 && i >= D/2) break; + VKQ2[j*(D_padded/2) + i] *= KQ_max_scale; + } + } + } + + __syncthreads(); + } +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j_VKQ = j0 + threadIdx.y; + if (ic0 + j_VKQ >= int(ne01.z)) { + return; + } + + float KQ_rowsum_j; + if (std::is_same<KQ_acc_t, float>::value) { + KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; + } else { + KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); + } + + const int j_dst_unrolled = ((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += warp_size) { + const int i = i0 + threadIdx.x; + if (i0 + warp_size > D && i >= D) { + break; + } + float dst_val = VKQ[j_VKQ*D_padded + i]; + if (gridDim.y == 1) { + dst_val /= KQ_rowsum_j; + } + dst[j_dst_unrolled*D + i] = dst_val; + } + + if (gridDim.y == 1 || threadIdx.x != 0) { + continue; + } + + float2 dst_meta_val; + if (std::is_same<KQ_acc_t, float>::value) { + dst_meta_val.x = KQ_max_f[j0/nwarps]; + } else { + dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); + } + dst_meta_val.y = KQ_rowsum_j; + dst_meta[j_dst_unrolled] = dst_meta_val; + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) +} + +constexpr int get_max_power_of_2(int x) { + return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; +} + +static_assert(get_max_power_of_2(1) == 1, "Test failed."); +static_assert(get_max_power_of_2(2) == 2, "Test failed."); +static_assert(get_max_power_of_2(4) == 4, "Test failed."); +static_assert(get_max_power_of_2(6) == 2, "Test failed."); + +// Number of VKQ rows calculated in parallel: +constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { + return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; +} + +static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); +static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed."); +static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed."); +static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed."); +static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed."); + +template <int D, int cols_per_block, typename KQ_acc_t> +void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + + constexpr int nwarps = 4; + + constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16; + const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + fattn_kernel_t fattn_kernel; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>; + } else { + constexpr bool use_logit_softcap = true; + fattn_kernel = flash_attn_ext_f16< + D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>; + } + launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size); +} + +void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + + const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); + const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size; + + if (prec != GGML_PREC_DEFAULT) { + if (Q->ne[1] <= 32 || Q->ne[0] > 128) { + constexpr int cols_per_block = 16; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } + } else { + constexpr int cols_per_block = 32; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); + break; + // case 256: + // ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); + // break; + default: + GGML_ABORT("fatal error"); + break; + } + } + return; + } + +#if !defined(GGML_USE_HIP) + if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) { + constexpr int cols_per_block = 8; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } + return; + } +#endif // !defined(GGML_USE_HIP) + + if (Q->ne[1] <= 32) { + constexpr int cols_per_block = 16; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } + return; + } + + constexpr int cols_per_block = 32; + switch (Q->ne[0]) { + case 64: + ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); + break; + case 80: + ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); + break; + case 96: + ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); + break; + case 112: + ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); + break; + case 128: + ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); + break; + case 256: + ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cuh new file mode 100644 index 0000000..cd3bfd4 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -0,0 +1,51 @@ +#pragma once + +#include "common.cuh" + +#if defined(GGML_USE_MUSA) +#define GGML_USE_WMMA_FATTN +#endif // defined(GGML_USE_MUSA) + +#if defined(GGML_HIP_ROCWMMA_FATTN) +#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) +#define GGML_USE_WMMA_FATTN +#elif defined(CDNA) +#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance" +#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) +#if defined(RDNA3) +#define GGML_USE_WMMA_FATTN +#endif // defined(RDNA3) +#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1 +#define GGML_USE_WMMA_FATTN +#elif defined(RDNA4) +#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance" +#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1 +#endif // defined(GGML_HIP_ROCWMMA_FATTN) + +// WMMA flash attention requires FP16 matrix instructions to be available for ggml code. +static bool ggml_cuda_should_use_wmma_fattn(const int cc) { +#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) + return false; +#else + if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) || + GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) { + return true; + } else if (GGML_CUDA_CC_IS_CDNA(cc)){ +#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) + return true; +#else + return false; +#endif // defined(GGML_HIP_ROCWMMA_FATTN) (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) + } else if (GGML_CUDA_CC_IS_RDNA4(cc)) { +#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1 + return true; +#else + return false; +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1 + } else { + return false; + } +#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) +} + +void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn.cu b/llama.cpp/ggml/src/ggml-cuda/fattn.cu new file mode 100644 index 0000000..721edd9 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/fattn.cu @@ -0,0 +1,482 @@ +#include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-mma-f16.cuh" +#include "fattn-tile.cuh" +#include "fattn-vec.cuh" +#include "fattn-wmma-f16.cuh" +#include "fattn.cuh" + +template <int DKQ, int DV, int ncols2> +static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const ggml_tensor * Q = dst->src[0]; + + if constexpr (ncols2 <= 8) { + if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst); + return; + } + } + + if constexpr (ncols2 <= 16) { + if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst); + return; + } + } + + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst); +} + +template <int DKQ, int DV> +static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + // Edge cases like no mask, ALiBi, unpadded K/V, or misaligned addresses for large data transfers + // are put into the template specialization without GQA optimizations. + bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + for (const ggml_tensor * t : {Q, K, V, mask}) { + if (t == nullptr || ggml_is_quantized(t->type)) { + continue; + } + for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { + if (t->nb[i] % 16 != 0) { + use_gqa_opt = false; + break; + } + } + } + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + + // On Volta the GQA optimizations aren't as impactful vs. minimizing wasted compute: + if (cc == GGML_CUDA_CC_VOLTA) { + if (use_gqa_opt && gqa_ratio % 8 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 4 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 2 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio > 4) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio > 2) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio > 1) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst); +} + +static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + + switch (Q->ne[0]) { + case 64: + GGML_ASSERT(V->ne[0] == 64); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst); + break; + case 80: + GGML_ASSERT(V->ne[0] == 80); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst); + break; + case 96: + GGML_ASSERT(V->ne[0] == 96); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst); + break; + case 112: + GGML_ASSERT(V->ne[0] == 112); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst); + break; + case 128: + GGML_ASSERT(V->ne[0] == 128); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst); + break; + case 256: + GGML_ASSERT(V->ne[0] == 256); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); + break; + case 576: { + // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. + GGML_ASSERT(V->ne[0] == 512); + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + if (gqa_ratio == 20) { // GLM 4.7 Flash + if (cc >= GGML_CUDA_CC_DGX_SPARK) { + if (Q->ne[1] <= 8) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + if (cc >= GGML_CUDA_CC_BLACKWELL) { + if (Q->ne[1] <= 4 && K->ne[1] >= 65536) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { + if (Q->ne[1] <= 4) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + if (cc >= GGML_CUDA_CC_TURING) { + if (Q->ne[1] <= 4) { + if (K->ne[1] <= 16384) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + // Volta: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + } else if (gqa_ratio % 16 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + } + } break; + default: + GGML_ABORT("fatal error"); + break; + } +} + +#define FATTN_VEC_CASE(D, type_K, type_V) \ + { \ + const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \ + const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \ + if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \ + ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \ + return; \ + } \ + } \ + +#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \ + FATTN_VEC_CASE( 64, type_K, type_V) \ + FATTN_VEC_CASE(128, type_K, type_V) \ + FATTN_VEC_CASE(256, type_K, type_V) \ + +static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_tensor * Q = dst->src[0]; + ggml_tensor * K = dst->src[1]; + ggml_tensor * V = dst->src[2]; + +#ifdef GGML_CUDA_FA_ALL_QUANTS + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) +#else + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) +#endif // GGML_CUDA_FA_ALL_QUANTS + + GGML_ABORT("fatal error"); +} + +// Best FlashAttention kernel for a specific GPU: +enum best_fattn_kernel { + BEST_FATTN_KERNEL_NONE = 0, + BEST_FATTN_KERNEL_TILE = 200, + BEST_FATTN_KERNEL_VEC = 100, + BEST_FATTN_KERNEL_WMMA_F16 = 300, + BEST_FATTN_KERNEL_MMA_F16 = 400, +}; + +static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) { +#ifndef FLASH_ATTN_AVAILABLE + GGML_UNUSED(device); GGML_UNUSED(dst); + return BEST_FATTN_KERNEL_NONE; +#endif// FLASH_ATTN_AVAILABLE + + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + // The effective batch size for the kernel can be increased by gqa_ratio. + // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded, + bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + for (const ggml_tensor * t : {Q, K, V, mask}) { + if (t == nullptr || ggml_is_quantized(t->type)) { + continue; + } + for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { + if (t->nb[i] % 16 != 0) { + gqa_opt_applies = false; + break; + } + } + } + + const int cc = ggml_cuda_info().devices[device].cc; + + switch (K->ne[0]) { + case 40: + case 64: + case 72: + case 80: + case 96: + case 128: + case 112: + case 256: + if (V->ne[0] != K->ne[0]) { + return BEST_FATTN_KERNEL_NONE; + } + break; + case 576: + if (V->ne[0] != 512) { + return BEST_FATTN_KERNEL_NONE; + } + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + break; + default: + return BEST_FATTN_KERNEL_NONE; + } + +#ifndef GGML_CUDA_FA_ALL_QUANTS + if (K->type != V->type) { + return BEST_FATTN_KERNEL_NONE; + } +#endif // GGML_CUDA_FA_ALL_QUANTS + + switch (K->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + break; + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: +#ifndef GGML_CUDA_FA_ALL_QUANTS + return BEST_FATTN_KERNEL_NONE; +#endif // GGML_CUDA_FA_ALL_QUANTS + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + break; + default: + return BEST_FATTN_KERNEL_NONE; + } + + if (mask && mask->ne[2] != 1) { + return BEST_FATTN_KERNEL_NONE; + } + + // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; + + // If Turing tensor cores are available, use them: + if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { + if (can_use_vector_kernel) { + if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { + if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { + return BEST_FATTN_KERNEL_VEC; + } + } else { + if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { + if (Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + } else { + if (Q->ne[1] == 1) { + return BEST_FATTN_KERNEL_VEC; + } + } + } + if (!gqa_opt_applies && Q->ne[1] == 1) { + return BEST_FATTN_KERNEL_VEC; + } + } + return BEST_FATTN_KERNEL_MMA_F16; + } + + if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { + int gqa_ratio_eff = 1; + const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; + while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { + gqa_ratio_eff *= 2; + } + if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + if (Q->ne[1] * gqa_ratio_eff <= 16) { + return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices. + } + return BEST_FATTN_KERNEL_MMA_F16; + } + + // Use the WMMA kernel if possible: + if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) { + if (can_use_vector_kernel && Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + return BEST_FATTN_KERNEL_WMMA_F16; + } + + if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) { + if (can_use_vector_kernel) { + if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { + if (Q->ne[1] == 1) { + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_VEC; + } + } + } else { + if (Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + } + } + int gqa_ratio_eff = 1; + const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; + while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { + gqa_ratio_eff *= 2; + } + if (Q->ne[1] * gqa_ratio_eff <= 8) { + return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized. + } + return BEST_FATTN_KERNEL_MMA_F16; + } + + // If there are no tensor cores available, use the generic tile kernel: + if (can_use_vector_kernel) { + if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { + if (Q->ne[1] == 1) { + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_VEC; + } + } + } else { + if (Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + } + } + return BEST_FATTN_KERNEL_TILE; +} + +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_set_device(ctx.device); + switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) { + case BEST_FATTN_KERNEL_NONE: + GGML_ABORT("fatal error"); + case BEST_FATTN_KERNEL_TILE: + ggml_cuda_flash_attn_ext_tile(ctx, dst); + break; + case BEST_FATTN_KERNEL_VEC: + ggml_cuda_flash_attn_ext_vec(ctx, dst); + break; + case BEST_FATTN_KERNEL_WMMA_F16: + ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); + break; + case BEST_FATTN_KERNEL_MMA_F16: + ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); + break; + } +} + +bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) { + return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE; +} diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn.cuh b/llama.cpp/ggml/src/ggml-cuda/fattn.cuh new file mode 100644 index 0000000..78705d5 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/fattn.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/fill.cu b/llama.cpp/ggml/src/ggml-cuda/fill.cu new file mode 100644 index 0000000..739062c --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/fill.cu @@ -0,0 +1,37 @@ +#include "fill.cuh" +#include "convert.cuh" + +#define CUDA_FILL_BLOCK_SIZE 256 + +template <typename T> +static __global__ void fill_kernel(T * dst, const int64_t k, const T value) { + const int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x; + if (i >= k) { + return; + } + dst[i] = value; +} + +void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + void * dst_d = dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(dst)); + + float value; + memcpy(&value, dst->op_params, sizeof(float)); + + const int64_t k = ggml_nelements(dst); + const int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE; + + switch (dst->type) { + case GGML_TYPE_F32: + fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((float *)dst_d, k, value); + break; + case GGML_TYPE_F16: + fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((half *)dst_d, k, ggml_cuda_cast<half>(value)); + break; + default: + GGML_ABORT("unsupported type"); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/fill.cuh b/llama.cpp/ggml/src/ggml-cuda/fill.cuh new file mode 100644 index 0000000..8443c83 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/fill.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/getrows.cu b/llama.cpp/ggml/src/ggml-cuda/getrows.cu new file mode 100644 index 0000000..2fab332 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/getrows.cu @@ -0,0 +1,286 @@ +#include "getrows.cuh" +#include "dequantize.cuh" +#include "convert.cuh" + +template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> +static __global__ void k_get_rows( + const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, + const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ + /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/ + /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, + /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, + const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { + + for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) { + for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) { + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i10 = blockIdx.x; + const int i11 = z / ne12; // TODO fastdiv + const int i12 = z % ne12; + + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; + + const int ib = i00/qk; // block index + const int iqs = (i00%qk)/qr; // quant index + const int iybs = i00 - i00%qk; // dst block start index + const int y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + float2 v; + dequantize_kernel(src0_row, ib, iqs, v); + + dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x); + dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y); + } + } +} + +template<typename src0_t, typename dst_t> +static __global__ void k_get_rows_float( + const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, + const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ + /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/ + /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, + /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, + const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { + + for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) { + for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { + // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. + const int i10 = blockIdx.x; + const int i11 = z / ne12; // TODO fastdiv + const int i12 = z % ne12; + + if (i00 >= ne00) { + return; + } + + const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; + + dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; + const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03); + + dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]); + } + } +} + +template<typename grad_t, typename dst_t> +static __global__ void k_get_rows_back_float( + const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) { + const int col = blockIdx.x*blockDim.x + threadIdx.x; + + if (col >= ncols) { + return; + } + + const int dst_row = blockIdx.y*blockDim.y + threadIdx.y; + + float sum = 0.0f; + + for (int64_t i = 0; i < nrows_grad; ++i) { + if (rows[i] != dst_row) { + continue; + } + sum += grad[i*ncols + col]; + } + + dst[dst_row*ncols + col] = sum; +} + +template<int qk, int qr, dequantize_kernel_t dq, typename dst_t> +static void get_rows_cuda_q( + const void * src0_d, const int32_t * src1_d, dst_t * dst_d, + const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { + const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); + const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); + const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX)); + + // strides in elements + // const size_t s0 = nb0 / sizeof(dst_t); + const size_t s1 = nb1 / sizeof(dst_t); + const size_t s2 = nb2 / sizeof(dst_t); + const size_t s3 = nb3 / sizeof(dst_t); + + const size_t s10 = nb10 / sizeof(int32_t); + const size_t s11 = nb11 / sizeof(int32_t); + const size_t s12 = nb12 / sizeof(int32_t); + // const size_t s13 = nb13 / sizeof(int32_t); + + GGML_ASSERT(ne00 % 2 == 0); + + k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>( + src0_d, src1_d, dst_d, + ne00, /*ne01, ne02, ne03,*/ + /*ne10,*/ ne11, ne12, /*ne13,*/ + /* s0,*/ s1, s2, s3, + /* nb00,*/ nb01, nb02, nb03, + s10, s11, s12/*, s13*/); +} + +template<typename src0_t, typename dst_t> +static void get_rows_cuda_float( + const src0_t * src0_d, const int32_t * src1_d, dst_t * dst_d, + const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { + const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); + const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; + const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX)); + + // strides in elements + // const size_t s0 = nb0 / sizeof(dst_t); + const size_t s1 = nb1 / sizeof(dst_t); + const size_t s2 = nb2 / sizeof(dst_t); + const size_t s3 = nb3 / sizeof(dst_t); + + const size_t s10 = nb10 / sizeof(int32_t); + const size_t s11 = nb11 / sizeof(int32_t); + const size_t s12 = nb12 / sizeof(int32_t); + // const size_t s13 = nb13 / sizeof(int32_t); + + k_get_rows_float<<<block_nums, block_dims, 0, stream>>>( + src0_d, src1_d, dst_d, + ne00, /*ne01, ne02, ne03,*/ + /*ne10,*/ ne11, ne12, /*ne13,*/ + /* s0,*/ s1, s2, s3, + /* nb00,*/ nb01, nb02, nb03, + s10, s11, s12/*, s13*/); +} + +template <typename dst_t> +static void ggml_cuda_get_rows_switch_src0_type( + const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d, + const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { + switch (src0_type) { + case GGML_TYPE_F16: + get_rows_cuda_float((const half *) src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_F32: + get_rows_cuda_float((const float *) src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_I32: + get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_BF16: + get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_Q4_0: + get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_Q4_1: + get_rows_cuda_q<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_Q5_0: + get_rows_cuda_q<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_Q5_1: + get_rows_cuda_q<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_Q8_0: + get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + default: + // TODO: k-quants + GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type)); + break; + } +} + +void get_rows_cuda( + const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type, + int64_t ne00, size_t nb01, size_t nb02, size_t nb03, + int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12, + size_t nb1, size_t nb2, size_t nb3, + cudaStream_t stream) { + switch (dst_type) { + case GGML_TYPE_F32: + ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_I32: + ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_F16: + ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + case GGML_TYPE_BF16: + ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (nv_bfloat16 *) dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; + default: + GGML_ABORT("%s: unsupported dst type: %s\n", __func__, ggml_type_name(dst_type)); + break; + } +} + +void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + cudaStream_t stream = ctx.stream(); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(ne13 == 1); + + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type)); + GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type)); + + get_rows_cuda(src0->data, src0->type, (const int32_t *) src1->data, dst->data, dst->type, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); +} + +void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output + const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass + + GGML_TENSOR_BINARY_OP_LOCALS + + const float * src0_d = (const float *) src0->data; + const int32_t * src1_d = (const int32_t *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_ASSERT(ne02*ne03 == 1); + GGML_ASSERT(ne12*ne13 == 1); + GGML_ASSERT(ne2*ne3 == 1); + + const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1); + const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE; + const dim3 block_nums(block_num_x, ne1, 1); + + k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/getrows.cuh b/llama.cpp/ggml/src/ggml-cuda/getrows.cuh new file mode 100644 index 0000000..3c5bea5 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/getrows.cuh @@ -0,0 +1,15 @@ +#include "common.cuh" + +#define CUDA_GET_ROWS_BLOCK_SIZE 256 +#define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256 + +void get_rows_cuda( + const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type, + int64_t ne00, size_t nb01, size_t nb02, size_t nb03, + int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12, + size_t nb1, size_t nb2, size_t nb3, + cudaStream_t stream); + +void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu b/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu new file mode 100644 index 0000000..b163468 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu @@ -0,0 +1,5118 @@ +#include "ggml-cuda.h" +#include "ggml-impl.h" +#include "ggml-backend-impl.h" + +#include "ggml-cuda/common.cuh" +#include "ggml-cuda/acc.cuh" +#include "ggml-cuda/add-id.cuh" +#include "ggml-cuda/arange.cuh" +#include "ggml-cuda/argmax.cuh" +#include "ggml-cuda/argsort.cuh" +#include "ggml-cuda/binbcast.cuh" +#include "ggml-cuda/clamp.cuh" +#include "ggml-cuda/concat.cuh" +#include "ggml-cuda/conv-transpose-1d.cuh" +#include "ggml-cuda/conv2d.cuh" +#include "ggml-cuda/conv2d-dw.cuh" +#include "ggml-cuda/conv2d-transpose.cuh" +#include "ggml-cuda/convert.cuh" +#include "ggml-cuda/count-equal.cuh" +#include "ggml-cuda/cpy.cuh" +#include "ggml-cuda/cross-entropy-loss.cuh" +#include "ggml-cuda/cumsum.cuh" +#include "ggml-cuda/diagmask.cuh" +#include "ggml-cuda/diag.cuh" +#include "ggml-cuda/fattn.cuh" +#include "ggml-cuda/getrows.cuh" +#include "ggml-cuda/im2col.cuh" +#include "ggml-cuda/mmf.cuh" +#include "ggml-cuda/mmq.cuh" +#include "ggml-cuda/mmvf.cuh" +#include "ggml-cuda/mmvq.cuh" +#include "ggml-cuda/norm.cuh" +#include "ggml-cuda/opt-step-adamw.cuh" +#include "ggml-cuda/opt-step-sgd.cuh" +#include "ggml-cuda/out-prod.cuh" +#include "ggml-cuda/pad.cuh" +#include "ggml-cuda/pool2d.cuh" +#include "ggml-cuda/quantize.cuh" +#include "ggml-cuda/rope.cuh" +#include "ggml-cuda/roll.cuh" +#include "ggml-cuda/scale.cuh" +#include "ggml-cuda/softcap.cuh" +#include "ggml-cuda/softmax.cuh" +#include "ggml-cuda/ssm-conv.cuh" +#include "ggml-cuda/ssm-scan.cuh" +#include "ggml-cuda/sum.cuh" +#include "ggml-cuda/sumrows.cuh" +#include "ggml-cuda/top-k.cuh" +#include "ggml-cuda/mean.cuh" +#include "ggml-cuda/tsembd.cuh" +#include "ggml-cuda/topk-moe.cuh" +#include "ggml-cuda/unary.cuh" +#include "ggml-cuda/upscale.cuh" +#include "ggml-cuda/wkv.cuh" +#include "ggml-cuda/gla.cuh" +#include "ggml-cuda/set.cuh" +#include "ggml-cuda/set-rows.cuh" +#include "ggml-cuda/pad_reflect_1d.cuh" +#include "ggml-cuda/solve_tri.cuh" +#include "ggml-cuda/tri.cuh" +#include "ggml-cuda/cumsum.cuh" +#include "ggml-cuda/fill.cuh" +#include "ggml.h" + +#include <algorithm> +#include <array> +#include <atomic> +#include <charconv> +#include <cinttypes> +#include <condition_variable> +#include <cstddef> +#include <cstdint> +#include <cfloat> +#include <initializer_list> +#include <limits> +#include <map> +#include <memory> +#include <mutex> +#include <cstdarg> +#include <cstdio> +#include <cstdlib> +#include <string> +#include <vector> +#include <unordered_set> + +static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); + +[[noreturn]] +void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { + int id = -1; // in case cudaGetDevice fails + (void)cudaGetDevice(&id); + + GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg); + GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line); + GGML_LOG_ERROR(" %s\n", stmt); + // abort with GGML_ABORT to get a stack trace + GGML_ABORT(GGML_CUDA_NAME " error"); +} + +// this is faster on Windows +// probably because the Windows CUDA libraries forget to make this check before invoking the drivers +void ggml_cuda_set_device(int device) { + int current_device; + CUDA_CHECK(cudaGetDevice(¤t_device)); + + if (device == current_device) { + return; + } + + CUDA_CHECK(cudaSetDevice(device)); +} + +int ggml_cuda_get_device() { + int id; + CUDA_CHECK(cudaGetDevice(&id)); + return id; +} + +static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { + ggml_cuda_set_device(device); + cudaError_t err; + if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) { + err = cudaMallocManaged(ptr, size); +#if defined(GGML_USE_HIP) + if (err == hipSuccess) { + CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device)); + } + + // fall back to cudaMalloc if not supported (e.g. on Windows) + if (err == hipErrorNotSupported) { + static bool warned_unsupported = false; + if (!warned_unsupported) { + GGML_LOG_WARN("hipMallocManaged unsupported, falling back to hipMalloc.\n"); + warned_unsupported = true; + } + + err = cudaMalloc(ptr, size); + } +#endif // defined(GGML_USE_HIP) + } else { + err = cudaMalloc(ptr, size); + } + return err; +} + +#if defined(GGML_USE_HIP) +static int ggml_cuda_parse_id(char devName[]) { + // A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp + // these values are not stable so this is susceptible to breakage + // https://github.com/ROCm/clr/blob/amd-staging/rocclr/device/device.cpp + int archMajor = 0x0; + int archMinor = 0x0; + int archNum = GGML_CUDA_CC_OFFSET_AMD; + int archLen = strlen(devName); + char archName[archLen + 1]; + + // strip leading 'gfx' while copying into our buffer + if (archLen > 3) { + strcpy(archName, &devName[3]); + archLen -= 3; + } + + // trim trailing :xnack- or :sramecc- statuses + archLen = strcspn(archName, ":"); + archName[archLen] = '\0'; + + // tease out the version information + if (archLen > 8) { + // versions labeled generic use '-' as delimiter + // strip the trailing "-generic" then iterate through what remains + if ((strstr(archName, "-generic"))) { + archName[archLen - 8] = '\0'; + char * pch; + if ((pch = strtok(archName, "-"))) { + archMajor = (int)strtoul(pch, 0, 16); + if ((pch = strtok(NULL, "-"))) { + archMinor = 0x10 * (int)strtoul(pch, 0, 16); + } + } + } + } else if (archLen >= 3) { + // last two digits should be the minor * 0x10 + stepping + archMinor = (int)strtoul(&archName[archLen - 2], 0, 16); + archName[archLen - 2] = '\0'; + + // only the major version remains + archMajor = (int)strtoul(archName, 0, 16); + } + archNum += archMajor * 0x100; + archNum += archMinor; + return archNum; +} +#endif // defined(GGML_USE_HIP) + +static ggml_cuda_device_info ggml_cuda_init() { + ggml_cuda_device_info info = {}; + + cudaError_t err = cudaGetDeviceCount(&info.device_count); + if (err != cudaSuccess) { + GGML_LOG_ERROR("%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err)); + return info; + } + + GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES); + + int64_t total_vram = 0; + GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); + + std::vector<std::pair<int, std::string>> turing_devices_without_mma; + for (int id = 0; id < info.device_count; ++id) { + int device_vmm = 0; + +#if defined(GGML_USE_VMM) + CUdevice device; + CU_CHECK(cuDeviceGet(&device, id)); + CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device)); + + if (device_vmm) { + CUmemAllocationProp alloc_prop = {}; + alloc_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + alloc_prop.location.id = id; + CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); + } +#endif // defined(GGML_USE_VMM) + info.devices[id].vmm = !!device_vmm; + + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); + + info.default_tensor_split[id] = total_vram; + total_vram += prop.totalGlobalMem; + info.devices[id].integrated = false; // Temporarily disabled due to issues with corrupted output (e.g. #15034) + info.devices[id].nsm = prop.multiProcessorCount; + info.devices[id].smpb = prop.sharedMemPerBlock; + info.devices[id].warp_size = prop.warpSize; + +#ifndef GGML_USE_MUSA + int supports_coop_launch = 0; + CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id)); + info.devices[id].supports_cooperative_launch = !!supports_coop_launch; +#else + info.devices[id].supports_cooperative_launch = false; +#endif // !(GGML_USE_MUSA) +#if defined(GGML_USE_HIP) + info.devices[id].smpbo = prop.sharedMemPerBlock; + + info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName); + if ((info.devices[id].cc & 0xff00) == 0x0) { + GGML_LOG_WARN("invalid architecture ID received for device %d %s: %s cc %d.%d\n", + id, prop.name, prop.gcnArchName, prop.major, prop.minor); + + // Fallback to prop.major and prop.minor + if (prop.major > 0) { + info.devices[id].cc = GGML_CUDA_CC_OFFSET_AMD + prop.major * 0x100; + info.devices[id].cc += prop.minor * 0x10; + } + } + GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n", + id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, + device_vmm ? "yes" : "no", prop.warpSize); +#elif defined(GGML_USE_MUSA) + // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs. + info.devices[id].warp_size = 32; + info.devices[id].smpbo = prop.sharedMemPerBlockOptin; + info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100; + info.devices[id].cc += prop.minor * 0x10; + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", + id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); +#else + info.devices[id].smpbo = prop.sharedMemPerBlockOptin; + info.devices[id].cc = 100*prop.major + 10*prop.minor; + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", + id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); + std::string device_name(prop.name); + if (device_name == "NVIDIA GeForce MX450") { + turing_devices_without_mma.push_back({ id, device_name }); + } else if (device_name == "NVIDIA GeForce MX550") { + turing_devices_without_mma.push_back({ id, device_name }); + } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") { + turing_devices_without_mma.push_back({ id, device_name }); + } + + // Temporary performance fix: + // Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls. + // TODO: Check for future drivers the default scheduling strategy and + // remove this call again when cudaDeviceScheduleSpin is default. + if (prop.major == 12 && prop.minor == 1) { + CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin)); + } + +#endif // defined(GGML_USE_HIP) + } + + if (ggml_cuda_highest_compiled_arch(GGML_CUDA_CC_TURING) >= GGML_CUDA_CC_TURING && !turing_devices_without_mma.empty()) { + GGML_LOG_INFO("The following devices will have suboptimal performance due to a lack of tensor cores:\n"); + for (size_t device_pos = 0; device_pos < turing_devices_without_mma.size(); device_pos++) { + GGML_LOG_INFO( + " Device %d: %s\n", turing_devices_without_mma[device_pos].first, turing_devices_without_mma[device_pos].second.c_str()); + } + GGML_LOG_INFO( + "Consider compiling with CMAKE_CUDA_ARCHITECTURES=61-virtual;80-virtual and DGGML_CUDA_FORCE_MMQ to force the use of the Pascal code for Turing.\n"); + } + + for (int id = 0; id < info.device_count; ++id) { + info.default_tensor_split[id] /= total_vram; + } + + // configure logging to stdout + // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); + + return info; +} + +const ggml_cuda_device_info & ggml_cuda_info() { + static ggml_cuda_device_info info = ggml_cuda_init(); + return info; +} + +// #define DEBUG_CUDA_MALLOC + +// buffer pool for cuda (legacy) +struct ggml_cuda_pool_leg : public ggml_cuda_pool { + static const int MAX_BUFFERS = 256; + + int device; + struct ggml_cuda_buffer { + void * ptr = nullptr; + size_t size = 0; + }; + + ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {}; + size_t pool_size = 0; + + explicit ggml_cuda_pool_leg(int device) : + device(device) { + } + + ~ggml_cuda_pool_leg() { + ggml_cuda_set_device(device); + for (int i = 0; i < MAX_BUFFERS; ++i) { + ggml_cuda_buffer & b = buffer_pool[i]; + if (b.ptr != nullptr) { + CUDA_CHECK(cudaFree(b.ptr)); + pool_size -= b.size; + } + } + GGML_ASSERT(pool_size == 0); + } + + void * alloc(size_t size, size_t * actual_size) override { +#ifdef DEBUG_CUDA_MALLOC + int nnz = 0; + size_t max_size = 0; +#endif + size_t best_diff = 1ull << 36; + int ibest = -1; + for (int i = 0; i < MAX_BUFFERS; ++i) { + ggml_cuda_buffer& b = buffer_pool[i]; + if (b.ptr != nullptr) { +#ifdef DEBUG_CUDA_MALLOC + ++nnz; + if (b.size > max_size) max_size = b.size; +#endif + if (b.size >= size) { + size_t diff = b.size - size; + if (diff < best_diff) { + best_diff = diff; + ibest = i; + if (!best_diff) { + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } + } + } + } + } + if (ibest >= 0) { + ggml_cuda_buffer& b = buffer_pool[ibest]; + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } + void * ptr; + size_t look_ahead_size = (size_t) (1.05 * size); + look_ahead_size = 256 * ((look_ahead_size + 255)/256); + ggml_cuda_set_device(device); + CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device)); + *actual_size = look_ahead_size; + pool_size += look_ahead_size; +#ifdef DEBUG_CUDA_MALLOC + GGML_LOG_INFO("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz, + (uint32_t)(max_size / 1024 / 1024), (uint32_t)(pool_size / 1024 / 1024), (uint32_t)(size / 1024 / 1024)); +#endif + return ptr; + } + + void free(void * ptr, size_t size) override { + for (int i = 0; i < MAX_BUFFERS; ++i) { + ggml_cuda_buffer& b = buffer_pool[i]; + if (b.ptr == nullptr) { + b.ptr = ptr; + b.size = size; + return; + } + } + GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n"); + ggml_cuda_set_device(device); + CUDA_CHECK(cudaFree(ptr)); + pool_size -= size; + } +}; + +// pool with virtual memory +#if defined(GGML_USE_VMM) +struct ggml_cuda_pool_vmm : public ggml_cuda_pool { + static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB + + int device; + CUdeviceptr pool_addr = 0; + size_t pool_used = 0; + size_t pool_size = 0; + size_t granularity; +#if defined(GGML_USE_HIP) + std::vector<std::pair<CUdeviceptr, size_t>> mappings; +#endif + + explicit ggml_cuda_pool_vmm(int device) : + device(device), + granularity(ggml_cuda_info().devices[device].vmm_granularity) { + } + + ~ggml_cuda_pool_vmm() { + if (pool_addr != 0) { +#if defined(GGML_USE_HIP) + // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285 + for (std::pair<CUdeviceptr, size_t> & mapping : mappings) { + CU_CHECK(cuMemUnmap(mapping.first, mapping.second)); + } +#else + CU_CHECK(cuMemUnmap(pool_addr, pool_size)); +#endif + CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE)); + } + } + + void * alloc(size_t size, size_t * actual_size) override { + // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types + const size_t alignment = 128; + size = alignment * ((size + alignment - 1) / alignment); + + size_t avail = pool_size - pool_used; + + if (size > avail) { + // round up to the next multiple of the granularity + size_t reserve_size = size - avail; + reserve_size = granularity * ((reserve_size + granularity - 1) / granularity); + + GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE); + + // allocate more physical memory + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = device; + CUmemGenericAllocationHandle handle; + CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0)); + + // reserve virtual address space (if not already reserved) + if (pool_addr == 0) { + CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0)); + } + + // map at the end of the pool + CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size); + CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0)); +#if defined(GGML_USE_HIP) + mappings.push_back({start_ptr, reserve_size}); +#endif + + // the memory allocation handle is no longer needed after mapping + CU_CHECK(cuMemRelease(handle)); + + // set access + CUmemAccessDesc access = {}; + access.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + access.location.id = device; + access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1)); + + // add to the pool + pool_size += reserve_size; + + //printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n", + // device, (unsigned long long) (pool_size/1024/1024), + // (unsigned long long) (reserve_size/1024/1024)); + } + + GGML_ASSERT(pool_addr != 0); + + void * ptr = (void *) ((CUdeviceptr)((char *)(pool_addr) + pool_used)); + *actual_size = size; + pool_used += size; + +#ifdef DEBUG_CUDA_MALLOC + printf("cuda pool[%d]: allocated %llu bytes at %llx\n", device, (unsigned long long) size, ptr); +#endif + + return ptr; + } + + void free(void * ptr, size_t size) override { +#ifdef DEBUG_CUDA_MALLOC + printf("cuda pool[%d]: freed %llu bytes at %llx\n", device, (unsigned long long) size, ptr); +#endif + + pool_used -= size; + + // all deallocations must be in reverse order of the allocations + GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used)); + } +}; +#endif // defined(GGML_USE_VMM) + +std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device, + [[maybe_unused]] int stream_no) { +#if defined(GGML_USE_VMM) + if (ggml_cuda_info().devices[device].vmm) { + return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device)); + } +#endif // defined(GGML_USE_VMM) + return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device)); +} + +// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error +// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured + +static std::mutex ggml_cuda_lock; +static std::condition_variable ggml_cuda_lock_cv; +static std::atomic<int> ggml_cuda_lock_counter; + +ggml_backend_cuda_context::~ggml_backend_cuda_context() { + std::unique_lock<std::mutex> lock(ggml_cuda_lock); + ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; }); + + if (copy_event != nullptr) { + CUDA_CHECK(cudaEventDestroy(copy_event)); + } + for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) { + for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) { + if (streams[i][j] != nullptr) { + CUDA_CHECK(cudaStreamDestroy(streams[i][j])); + } + } + if (cublas_handles[i] != nullptr) { + CUBLAS_CHECK(cublasDestroy(cublas_handles[i])); + } + } +} + + +// cuda buffer + +struct ggml_backend_cuda_buffer_context { + int device; + void * dev_ptr = nullptr; + std::string name; + + ggml_backend_cuda_buffer_context(int device, void * dev_ptr) : + device(device), dev_ptr(dev_ptr), + name(GGML_CUDA_NAME + std::to_string(device)) { + } + + ~ggml_backend_cuda_buffer_context() { + CUDA_CHECK(cudaFree(dev_ptr)); + } +}; + +static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + delete ctx; +} + +static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_cuda_buffer_free_buffer; +} + +static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + return ctx->dev_ptr; +} + +static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + + if (tensor->view_src != NULL) { + assert(tensor->view_src->buffer->buft == buffer->buft); + return GGML_STATUS_SUCCESS; + } + + if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) { + // initialize padding to 0 to avoid possible NaN values + const size_t original_size = ggml_nbytes(tensor); + const size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor); + + if (padded_size > original_size) { + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemset((char *)tensor->data + original_size, 0, padded_size - original_size)); + } + } + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); +} + +static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); +} + +static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); +} + +static bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { + if (ggml_backend_buffer_is_cuda(src->buffer)) { + ggml_backend_cuda_buffer_context * src_ctx = (ggml_backend_cuda_buffer_context *)src->buffer->context; + ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *)dst->buffer->context; + if (src_ctx->device == dst_ctx->device) { + CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(src), cudaMemcpyDeviceToDevice, cudaStreamPerThread)); + } else { +#ifdef GGML_CUDA_NO_PEER_COPY + return false; +#else + CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, dst_ctx->device, src->data, src_ctx->device, ggml_nbytes(src), cudaStreamPerThread)); +#endif + } + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); + return true; + } + return false; + + GGML_UNUSED(buffer); +} + +static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemsetAsync(ctx->dev_ptr, value, buffer->size, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); +} + +static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = { + /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer, + /* .get_base = */ ggml_backend_cuda_buffer_get_base, + /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor, + /* .clear = */ ggml_backend_cuda_buffer_clear, + /* .reset = */ NULL, +}; + +// cuda buffer type +struct ggml_backend_cuda_buffer_type_context { + int device; + std::string name; +}; + +static const char * ggml_backend_cuda_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + ggml_backend_cuda_buffer_type_context * ctx = (ggml_backend_cuda_buffer_type_context *)buft->context; + + return ctx->name.c_str(); +} + +static bool ggml_backend_buft_is_cuda(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_cuda_buffer_type_get_name; +} + +static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context; + + ggml_cuda_set_device(buft_ctx->device); + + void * dev_ptr; + cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device); + if (err != cudaSuccess) { + // clear the error + (void)cudaGetLastError(); + GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err)); + return nullptr; + } + + ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr); + + return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size); +} + +static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return 128; + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + size_t size = ggml_nbytes(tensor); + int64_t ne0 = tensor->ne[0]; + + if (ggml_is_quantized(tensor->type)) { + if (ne0 % MATRIX_ROW_PADDING != 0) { + GGML_ASSERT(tensor->nb[0] == ggml_element_size(tensor)); + size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + } + + return size; + + GGML_UNUSED(buft); +} + +static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = { + /* .get_name = */ ggml_backend_cuda_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cuda_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cuda_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; + +ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + + if (device >= ggml_backend_cuda_get_device_count()) { + return nullptr; + } + + static ggml_backend_buffer_type ggml_backend_cuda_buffer_types[GGML_CUDA_MAX_DEVICES]; + + static bool ggml_backend_cuda_buffer_type_initialized = false; + + if (!ggml_backend_cuda_buffer_type_initialized) { + for (int i = 0; i < ggml_backend_cuda_get_device_count(); i++) { + ggml_backend_cuda_buffer_types[i] = { + /* .iface = */ ggml_backend_cuda_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), i), + /* .context = */ new ggml_backend_cuda_buffer_type_context{i, GGML_CUDA_NAME + std::to_string(i)}, + }; + } + ggml_backend_cuda_buffer_type_initialized = true; + } + + return &ggml_backend_cuda_buffer_types[device]; +} + +// cuda split buffer + +static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) { + int64_t row_rounding = 0; + for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) { + continue; + } + + const int cc = ggml_cuda_info().devices[id].cc; + row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc)); + } + return row_rounding; +} + +static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) { + const int64_t nrows = ggml_nrows(tensor); + const int64_t rounding = get_row_rounding(tensor_split); + + *row_low = id == 0 ? 0 : nrows*tensor_split[id]; + *row_low -= *row_low % rounding; + + if (id == ggml_backend_cuda_get_device_count() - 1) { + *row_high = nrows; + } else { + *row_high = nrows*tensor_split[id + 1]; + *row_high -= *row_high % rounding; + } +} + +static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]); +} + +struct ggml_backend_cuda_split_buffer_type_context { + int main_device; + std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split; + std::string name; +}; + +struct ggml_backend_cuda_split_buffer_context { + ~ggml_backend_cuda_split_buffer_context() { + for (ggml_tensor_extra_gpu * extra : tensor_extras) { + for (int id = 0; id < GGML_CUDA_MAX_DEVICES; ++id) { + for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) { + if (extra->events[id][is] != nullptr) { + CUDA_CHECK(cudaEventDestroy(extra->events[id][is])); + } + } + if (extra->data_device[id] != nullptr) { + CUDA_CHECK(cudaFree(extra->data_device[id])); + } + } + delete extra; + } + } + + std::vector<ggml_tensor_extra_gpu *> tensor_extras; +}; + + +static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context; + delete ctx; +} + +static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) { + // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced + return (void *)0x1000; + + GGML_UNUSED(buffer); +} + +static enum ggml_status ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported + GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors"); + + ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context; + ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context; + + const int64_t ne0 = tensor->ne[0]; + + ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; + ctx->tensor_extras.push_back(extra); + + for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + int64_t row_low, row_high; + get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id); + + int64_t nrows_split = row_high - row_low; + if (nrows_split == 0) { + continue; + } + + size_t size = ggml_nbytes_split(tensor, nrows_split); + const size_t original_size = size; + + // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses + if (ne0 % MATRIX_ROW_PADDING != 0) { + size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + + // FIXME: do not crash if cudaMalloc fails + // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first + ggml_cuda_set_device(id); + char * buf; + CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id)); + + // set padding to 0 to avoid possible NaN values + if (size > original_size) { + CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size)); + } + + extra->data_device[id] = buf; + + for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) { + CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming)); + } + } + tensor->extra = extra; + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + // split tensors must always be set in their entirety at once + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors"); + + ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context; + + const int64_t ne0 = tensor->ne[0]; + const size_t nb1 = tensor->nb[1]; + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra; + + for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + int64_t row_low, row_high; + get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id); + + int64_t nrows_split = row_high - row_low; + if (nrows_split == 0) { + continue; + } + + const size_t offset_split = row_low*nb1; + size_t size = ggml_nbytes_split(tensor, nrows_split); + const size_t original_size = size; + + // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses + if (ne0 % MATRIX_ROW_PADDING != 0) { + size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + + const char * buf_host = (const char *)data + offset_split; + CUDA_CHECK(cudaMemcpyAsync(extra->data_device[id], buf_host, original_size, cudaMemcpyHostToDevice, cudaStreamPerThread)); + } + + for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); + } +} + +static void ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + // split tensors must always be set in their entirety at once + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors"); + + ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context; + + const int64_t ne0 = tensor->ne[0]; + const size_t nb1 = tensor->nb[1]; + ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra; + + for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + int64_t row_low, row_high; + get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id); + + int64_t nrows_split = row_high - row_low; + if (nrows_split == 0) { + continue; + } + + const size_t offset_split = row_low*nb1; + size_t size = ggml_nbytes_split(tensor, nrows_split); + const size_t original_size = size; + + // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses + if (ne0 % MATRIX_ROW_PADDING != 0) { + size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + + char * buf_host = (char *)data + offset_split; + CUDA_CHECK(cudaMemcpyAsync(buf_host, extra->data_device[id], original_size, cudaMemcpyDeviceToHost, cudaStreamPerThread)); + } + + for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); + } +} + +static void ggml_backend_cuda_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + GGML_UNUSED(buffer); + GGML_UNUSED(value); +} + +static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = { + /* .free_buffer = */ ggml_backend_cuda_split_buffer_free_buffer, + /* .get_base = */ ggml_backend_cuda_split_buffer_get_base, + /* .init_tensor = */ ggml_backend_cuda_split_buffer_init_tensor, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_cuda_split_buffer_get_tensor, + /* .cpy_tensor = */ NULL, + /* .clear = */ ggml_backend_cuda_split_buffer_clear, + /* .reset = */ NULL, +}; + +// cuda split buffer type + +static const char * ggml_backend_cuda_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context; + + return ctx->name.c_str(); +} + +static bool ggml_backend_buft_is_cuda_split(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_cuda_split_buffer_type_get_name; +} + +static ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point + // instead, we allocate them for each tensor separately in init_tensor + // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated, + // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct. + ggml_backend_cuda_split_buffer_context * ctx = new ggml_backend_cuda_split_buffer_context(); + + return ggml_backend_buffer_init(buft, ggml_backend_cuda_split_buffer_interface, ctx, size); +} + +static size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return 128; + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context; + GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors"); + + size_t total_size = 0; + + const int64_t ne0 = tensor->ne[0]; + + for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + int64_t row_low, row_high; + get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, id); + + int64_t nrows_split = row_high - row_low; + if (nrows_split == 0) { + continue; + } + + total_size += ggml_nbytes_split(tensor, nrows_split); + + // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses + if (ne0 % MATRIX_ROW_PADDING != 0) { + total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + } + + return total_size; +} + +static bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + return false; + + GGML_UNUSED(buft); +} + +static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface = { + /* .get_name = */ ggml_backend_cuda_split_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cuda_split_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cuda_split_buffer_type_get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ ggml_backend_cuda_split_buffer_type_get_alloc_size, + /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host, +}; + +ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + + static std::map<std::pair<int, std::array<float, GGML_CUDA_MAX_DEVICES>>, struct ggml_backend_buffer_type> buft_map; + + std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split_arr = {}; + + bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_CUDA_MAX_DEVICES, [](float x) { return x == 0.0f; }); + if (all_zero) { + tensor_split_arr = ggml_cuda_info().default_tensor_split; + } else { + float split_sum = 0.0f; + for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) { + tensor_split_arr[i] = split_sum; + split_sum += tensor_split[i]; + } + for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) { + tensor_split_arr[i] /= split_sum; + } + } + + auto it = buft_map.find({main_device, tensor_split_arr}); + if (it != buft_map.end()) { + return &it->second; + } + auto * ctx = new ggml_backend_cuda_split_buffer_type_context{ + main_device, + tensor_split_arr, + GGML_CUDA_NAME + std::to_string(main_device) + "_Split", + }; + + struct ggml_backend_buffer_type buft { + /* .iface = */ ggml_backend_cuda_split_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), main_device), + /* .context = */ ctx, + }; + + auto result = buft_map.emplace(std::make_pair(main_device, tensor_split_arr), buft); + return &result.first->second; +} + +// host buffer type + +static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_type_t buft) { + return GGML_CUDA_NAME "_Host"; + + GGML_UNUSED(buft); +} + +static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name; +} + +static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { + CUDA_CHECK(cudaFreeHost(buffer->context)); +} + +static void * ggml_cuda_host_malloc(size_t size) { + if (getenv("GGML_CUDA_NO_PINNED") != nullptr) { + return nullptr; + } + + void * ptr = nullptr; + cudaError_t err = cudaMallocHost((void **) &ptr, size); + if (err != cudaSuccess) { + // clear the error + (void)cudaGetLastError(); + GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__, + size / 1024.0 / 1024.0, cudaGetErrorString(err)); + return nullptr; + } + + return ptr; +} + +static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + void * ptr = ggml_cuda_host_malloc(size); + + if (ptr == nullptr) { + // fallback to cpu buffer + return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + } + + ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer; + + return buffer; +} + +ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() { + static struct ggml_backend_buffer_type ggml_backend_cuda_buffer_type_host = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cuda_host_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_cuda_host_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment, + /* .get_max_size = */ NULL, // defaults to SIZE_MAX + /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, + /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), 0), + /* .context = */ nullptr, + }; + + return &ggml_backend_cuda_buffer_type_host; +} + +//static bool ggml_backend_buffer_is_cuda_host(ggml_backend_buffer_t buffer) { +// return buffer->buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name; +//} + +/// kernels + +typedef void (*ggml_cuda_op_mul_mat_t)( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream); + +#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE +#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128 +#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE + +#define MUL_MAT_SRC1_COL_STRIDE 128 + +static cudaError_t ggml_cuda_cpy_tensor_2d( + void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { + + const char * src_ptr = (const char *) src->data; + char * dst_ptr = (char *) dst; + + const int64_t ne0 = src->ne[0]; + const int64_t nb0 = src->nb[0]; + const int64_t nb1 = src->nb[1]; + const int64_t nb2 = src->nb[2]; + const int64_t nb3 = src->nb[3]; + const enum ggml_type type = src->type; + const int64_t ts = ggml_type_size(type); + const int64_t bs = ggml_blck_size(type); + const int64_t i1_diff = i1_high - i1_low; + + const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3; + if (nb0 == ts && nb1 == ts*ne0/bs) { + return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, cudaMemcpyDeviceToDevice, stream); + } else if (nb0 == ts) { + return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyDeviceToDevice, stream); + } else { + for (int64_t i1 = 0; i1 < i1_diff; i1++) { + const void * rx = (const void *) ((const char *) x + i1*nb1); + void * rd = (void *) (dst_ptr + i1*ts*ne0/bs); + // pretend the row is a matrix with cols=1 + cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyDeviceToDevice, stream); + if (r != cudaSuccess) { + return r; + } + } + return cudaSuccess; + } +} + +static void ggml_cuda_op_mul_mat_cublas( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + GGML_ASSERT(src0_dd_i != nullptr); + GGML_ASSERT(src1_ddf_i != nullptr); + GGML_ASSERT(dst_dd_i != nullptr); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = dst->ne[0]; + + const int64_t row_diff = row_high - row_low; + + int id = ggml_cuda_get_device(); + + // the main device has a larger memory buffer to hold the results from all GPUs + // ldc == nrows of the matrix that cuBLAS writes into + int64_t ldc = id == ctx.device ? ne0 : row_diff; + + const int cc = ggml_cuda_info().devices[id].cc; + + const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) || + (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2); + + const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT; + + if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { + ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id)); + if (src1->type != GGML_TYPE_BF16) { + const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type); + GGML_ASSERT(to_bf16_cuda != nullptr); + size_t ne = src1_ncols*ne10; + src1_as_bf16.alloc(ne); + to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), ne, stream); + } + const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get(); + const nv_bfloat16 * src0_ptr = (const nv_bfloat16 *)src0_dd_i; + ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols); + + const float alpha_f32 = 1.0f; + const float beta_f32 = 0.0f; + + CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha_f32, src0_ptr, CUDA_R_16BF, ne00, + src1_ptr, CUDA_R_16BF, ne10, + &beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16); + to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream); + } else if (fast_fp16_hardware_available(cc) && use_fp16) { + // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32 + ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id)); + if (src0->type != GGML_TYPE_F16) { + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type); + GGML_ASSERT(to_fp16_cuda != nullptr); + size_t ne = row_diff*ne00; + src0_as_f16.alloc(ne); + to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream); + } + const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get(); + + ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id)); + if (src1->type != GGML_TYPE_F16) { + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); + GGML_ASSERT(to_fp16_cuda != nullptr); + size_t ne = src1_ncols*ne10; + src1_as_f16.alloc(ne); + to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream); + } + const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get(); + + CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); + + if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { + const float alpha = 1.0f; + const float beta = 0.0f; + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta, dst_dd_i, CUDA_R_32F, ldc, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } else { + ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols); + + const half alpha_f16 = 1.0f; + const half beta_f16 = 0.0f; + + CUBLAS_CHECK( + cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha_f16, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta_f16, dst_f16.get(), CUDA_R_16F, ldc, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream); + } + } else { + ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id)); + ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id)); + + if (src0->type != GGML_TYPE_F32) { + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type); + GGML_ASSERT(to_fp32_cuda != nullptr); + src0_ddq_as_f32.alloc(row_diff*ne00); + to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream); + } + if (src1->type != GGML_TYPE_F32) { + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type); + GGML_ASSERT(to_fp32_cuda != nullptr); + src1_ddq_as_f32.alloc(src1_ncols*ne10); + to_fp32_cuda(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream); + } + + const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get(); + const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get(); + + const float alpha = 1.0f; + const float beta = 0.0f; + + CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); + CUBLAS_CHECK( + cublasSgemm(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, + row_diff, src1_ncols, ne10, + &alpha, src0_ddf_i, ne00, + src1_ddf1_i, ne10, + &beta, dst_dd_i, ldc)); + } + + GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size); +} + +static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { + static bool peer_access_enabled = false; + + const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE; + + if (peer_access_enabled == enable_peer_access) { + return; + } + +#ifdef NDEBUG + for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + ggml_cuda_set_device(id); + CUDA_CHECK(cudaDeviceSynchronize()); + } + + for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + ggml_cuda_set_device(id); + + for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) { + if (id == id_other) { + continue; + } + if (id != main_device && id_other != main_device) { + continue; + } + + int can_access_peer; + CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); + if (can_access_peer) { + if (enable_peer_access) { + cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0); + if (err != cudaErrorPeerAccessAlreadyEnabled) { + CUDA_CHECK(err); + } else { + // reset the error + (void)cudaGetLastError(); + } + } else { + cudaError_t err = cudaDeviceDisablePeerAccess(id_other); + if (err != cudaErrorPeerAccessNotEnabled) { + CUDA_CHECK(err); + } else { + // reset the error + (void)cudaGetLastError(); + } + } + } + } + } + + ggml_cuda_set_device(main_device); +#endif // NDEBUG + + peer_access_enabled = enable_peer_access; + + GGML_UNUSED(main_device); +} + +static cudaError_t ggml_cuda_Memcpy2DPeerAsync( + void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) { + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices + cudaMemcpy3DPeerParms p = {}; + p.dstDevice = dstDevice; + p.dstPtr = make_cudaPitchedPtr(dst, dpitch, dpitch, height); + p.srcDevice = srcDevice; + p.srcPtr = make_cudaPitchedPtr(src, spitch, spitch, height); + p.extent = make_cudaExtent(width, height, 1); + return cudaMemcpy3DPeerAsync(&p, stream); +#else + // HIP does not support cudaMemcpy3DPeerAsync or vmm pools + GGML_UNUSED(dstDevice); + GGML_UNUSED(srcDevice); + return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream); +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +} + +static void ggml_cuda_op_mul_mat( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op, + quantize_cuda_t quantize_src1) { + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + const int64_t nrows1 = ggml_nrows(src1); + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + // const int64_t nb10 = src1->nb[0]; + const int64_t nb11 = src1->nb[1]; + const int64_t nb12 = src1->nb[2]; + const int64_t nb13 = src1->nb[3]; + + const int64_t nb2 = dst->nb[2]; + const int64_t nb3 = dst->nb[3]; + + ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context; + ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context; + + GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1)); + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + const int64_t i02_divisor = ne12 / ne02; + const int64_t i03_divisor = ne13 / ne03; + + const size_t src0_ts = ggml_type_size(src0->type); + const size_t src0_bs = ggml_blck_size(src0->type); + const size_t q8_1_ts = sizeof(block_q8_1); + const size_t q8_1_bs = QK8_1; + + const bool src0_is_contiguous = ggml_is_contiguous(src0); + const bool src1_is_contiguous = ggml_is_contiguous(src1); + + const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING); + + const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft); + GGML_ASSERT(!(split && ne02 > 1)); + GGML_ASSERT(!(split && ne03 > 1)); + GGML_ASSERT(!(split && ne02 < ne12)); + GGML_ASSERT(!(split && ne03 < ne13)); + + ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr; + + + std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split; + if (split) { + ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context; + tensor_split = buft_ctx->tensor_split; + } + + struct dev_data { + int cc; + + ggml_cuda_pool_alloc<char> src0_dd_alloc; + ggml_cuda_pool_alloc<float> src1_ddf_alloc; + ggml_cuda_pool_alloc<char> src1_ddq_alloc; + ggml_cuda_pool_alloc<float> dst_dd_alloc; + + char * src0_dd = nullptr; + float * src1_ddf = nullptr; // float + char * src1_ddq = nullptr; // q8_1 + float * dst_dd = nullptr; + + int64_t row_low; + int64_t row_high; + }; + + dev_data dev[GGML_CUDA_MAX_DEVICES]; + + int used_devices = 0; + + for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + dev[id].cc = ggml_cuda_info().devices[id].cc; + + // by default, use all rows + dev[id].row_low = 0; + dev[id].row_high = ne01; + + // for multi GPU, get the row boundaries from tensor split + // and round to mul_mat_q tile sizes + if (split) { + const int64_t rounding = get_row_rounding(tensor_split); + + if (id != 0) { + dev[id].row_low = ne01*tensor_split[id]; + if (dev[id].row_low < ne01) { + dev[id].row_low -= dev[id].row_low % rounding; + } + } + + if (id != ggml_backend_cuda_get_device_count() - 1) { + dev[id].row_high = ne01*tensor_split[id + 1]; + if (dev[id].row_high < ne01) { + dev[id].row_high -= dev[id].row_high % rounding; + } + } + } + } + + for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) { + continue; + } + + used_devices++; + + const bool src1_on_device = id == src1_ctx->device; + const bool dst_on_device = id == dst_ctx->device; + + ggml_cuda_set_device(id); + cudaStream_t stream = ctx.stream(id, 0); + + if (src0_is_contiguous) { + dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data; + } else { + // If src0 is not contiguous it will be copied to a temporary buffer. + // This buffer needs to be cleared entirely because multiple regions will function as padding. + const size_t nbytes_data = ggml_nbytes(src0); + const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); + dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding); + CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream)); + } + + // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared: + if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) { + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); + GGML_ASSERT(!src0->view_src); + const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00); + const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); + CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream)); + } + + if (src1_on_device && src1_is_contiguous) { + dev[id].src1_ddf = (float *) src1->data; + } else { + dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1)); + } + + if (quantize_src1) { + size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs; + if (quantize_src1 == quantize_mmq_q8_1_cuda) { + src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq); + } + dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size); + + if (src1_on_device && src1_is_contiguous) { + quantize_src1( + dev[id].src1_ddf, nullptr, dev[id].src1_ddq, src0->type, ne10, + nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float), + src1_padded_col_size, ne11, ne12, ne13, stream); + CUDA_CHECK(cudaGetLastError()); + } + } + + if (dst_on_device) { + dev[id].dst_dd = (float *) dst->data; + } else { + const size_t size_dst_ddf = split ? (dev[id].row_high - dev[id].row_low)*ne1 : ggml_nelements(dst); + dev[id].dst_dd = dev[id].dst_dd_alloc.alloc(ctx.pool(id), size_dst_ddf); + } + } + + // if multiple devices are used they need to wait for the main device + // here an event is recorded that signals that the main device has finished calculating the input data + if (split && used_devices > 1) { + ggml_cuda_set_device(ctx.device); + CUDA_CHECK(cudaEventRecord(src0_extra->events[ctx.device][0], ctx.stream())); + } + + const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11; + for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) { + const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_CUDA_MAX_STREAMS : 0; + const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride; + + for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) { + continue; + } + + const bool src1_on_device = id == src1_ctx->device; + const bool dst_on_device = id == dst_ctx->device; + const int64_t row_diff = dev[id].row_high - dev[id].row_low; + + ggml_cuda_set_device(id); + cudaStream_t stream = ctx.stream(id, is); + + // wait for main GPU data if necessary + if (split && (id != ctx.device || is != 0)) { + CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[ctx.device][0], 0)); + } + + for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) { + const int64_t i03 = i0 / ne12; + const int64_t i02 = i0 % ne12; + + size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs; + if (quantize_src1 == quantize_mmq_q8_1_cuda) { + src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq); + } else { + src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs; + } + + // for split tensors the data begins at i0 == i0_offset_low + const size_t nbytes_src0_matrix = ne01*ne00*src0_ts / src0_bs; + char * src0_dd_i = dev[id].src0_dd + ((i03/i03_divisor)*ne02 + (i02/i02_divisor)) * nbytes_src0_matrix; + float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10; + char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset; + float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff); + + // the main device memory buffer can be on VRAM scratch, with space for all partial results + // in that case an offset on dst_ddf_i is needed + if (id == ctx.device) { + dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split + } + + // copy src0, src1 to device if necessary + if (src1_is_contiguous) { + if (id != ctx.device) { + if (quantize_src1) { + char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset; + if (quantize_src1 == quantize_mmq_q8_1_cuda) { + const size_t pitch = ne11*sizeof(block_q8_1_mmq); + const size_t width = src1_ncols*sizeof(block_q8_1_mmq); + const size_t height = src1_padded_col_size/(4*QK8_1); + CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream)); + } else { + CUDA_CHECK(cudaMemcpyPeerAsync( + src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream)); + } + } else { + float * src1_ddf_i_source = (float *) src1->data; + src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10; + CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddf_i, id, src1_ddf_i_source, ctx.device, + src1_ncols*ne10*sizeof(float), stream)); + } + } + } else if (src1_on_device && !src1_is_contiguous) { + CUDA_CHECK(ggml_cuda_cpy_tensor_2d( + src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); + } else { + GGML_ABORT("fatal error"); + } + + if (quantize_src1 && !src1_is_contiguous) { + quantize_src1( + src1_ddf_i, nullptr, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10, + src1_padded_col_size, src1_ncols, 1, 1, stream); + CUDA_CHECK(cudaGetLastError()); + } + + if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) { + CUDA_CHECK(ggml_cuda_cpy_tensor_2d( + src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream)); + } + + // do the computation + op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i, + dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, stream); + CUDA_CHECK(cudaGetLastError()); + + // copy dst to host or other device if necessary + if (!dst_on_device) { + void * dst_off_device = dst->data; + if (split) { + // src0 = weight matrix is saved as a transposed matrix for better memory layout. + // dst is NOT transposed. + // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU. + // Instead they need to be copied to the correct slice in ne0 = dst row index. + // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results. + float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); + GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); + dhf_dst_i += src1_col_0*ne0 + dev[id].row_low; + CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync( + dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream)); + } else { + float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); + GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); + dhf_dst_i += src1_col_0*ne0; + CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), cudaMemcpyDeviceToDevice, stream)); + } + } + + // add event for the main device to wait on until other device is done + if (split && (id != ctx.device || is != 0)) { + CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream)); + } + } + } + } + + // main device waits for all other devices to be finished + if (split && ggml_backend_cuda_get_device_count() > 1) { + int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE; + is_max = is_max <= GGML_CUDA_MAX_STREAMS ? is_max : GGML_CUDA_MAX_STREAMS; + + ggml_cuda_set_device(ctx.device); + for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + if (dev[id].row_low == dev[id].row_high) { + continue; + } + for (int64_t is = 0; is < is_max; ++is) { + CUDA_CHECK(cudaStreamWaitEvent(ctx.stream(), src0_extra->events[id][is], 0)); + } + } + } +} + +static __global__ void k_compute_batched_ptrs( + const void * src0_as_f16, const void * src1_as_f16, char * dst, + const void ** ptrs_src, void ** ptrs_dst, + int64_t ne12, int64_t ne13, + int64_t ne23, + size_t nb02, size_t nb03, + size_t nb12, size_t nb13, + size_t nbd2, size_t nbd3, + int64_t r2, int64_t r3) { + const int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x; + const int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y; + + if (i13 >= ne13 || i12 >= ne12) { + return; + } + + const int64_t i03 = i13 / r3; + const int64_t i02 = i12 / r2; + + ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03; + ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13; + ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3; +} + +// Type traits for mapping ggml types to CUDA/cuBLAS types +template<ggml_type T> +struct batched_mul_mat_traits; + +template<> +struct batched_mul_mat_traits<GGML_TYPE_F32> { + using cuda_type = float; + static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; + static inline const cudaDataType_t data_type = CUDA_R_32F; + static inline const ggml_type ggml_type_val = GGML_TYPE_F32; + static inline const float alpha = 1.0f; + static inline const float beta = 0.0f; + static inline const void* get_alpha() { static const float val = alpha; return &val; } + static inline const void* get_beta() { static const float val = beta; return &val; } + static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); } +}; + +template<> +struct batched_mul_mat_traits<GGML_TYPE_BF16> { + using cuda_type = nv_bfloat16; + static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; + static inline const cudaDataType_t data_type = CUDA_R_16BF; + static inline const ggml_type ggml_type_val = GGML_TYPE_BF16; + static inline const float alpha = 1.0f; + static inline const float beta = 0.0f; + static inline const void* get_alpha() { static const float val = alpha; return &val; } + static inline const void* get_beta() { static const float val = beta; return &val; } + static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); } +}; + +template<> +struct batched_mul_mat_traits<GGML_TYPE_F16> { + using cuda_type = half; + static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; + static inline const cudaDataType_t data_type = CUDA_R_16F; + static inline const ggml_type ggml_type_val = GGML_TYPE_F16; + static inline const half alpha = 1.0; + static inline const half beta = 0.0; + static inline const void* get_alpha() { static const half val = alpha; return &val; } + static inline const void* get_beta() { static const half val = beta; return &val; } + static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); } +}; + +template<ggml_type src0_type> +static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + using traits = batched_mul_mat_traits<src0_type>; + using cuda_t = typename traits::cuda_type; + + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft)); + GGML_ASSERT(src0->type == src0_type); + GGML_ASSERT(ggml_is_contiguous(dst)); + + // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst. + // As long as dst is contiguous this does not matter though. + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t ne_dst = ggml_nelements(dst); + cudaStream_t main_stream = ctx.stream(); + CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream)); + + float * dst_ddf = (float *) dst->data; + const size_t ts_src1 = ggml_type_size(src1->type); + GGML_ASSERT(nb10 == ts_src1); + int64_t s11 = nb11 / ts_src1; + int64_t s12 = nb12 / ts_src1; + int64_t s13 = nb13 / ts_src1; + + const cuda_t * src0_ptr = nullptr; + const cuda_t * src1_ptr = nullptr; + + ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool()); + ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool()); + + bool is_src0_cont_2 = ggml_is_contiguous_2(src0); + bool is_src1_cont_2 = ggml_is_contiguous_2(src1); + + // Handle src0 + src0_ptr = (const cuda_t *) src0->data; + + // Handle src1 - convert if necessary + if (src1->type == src0_type) { + src1_ptr = (const cuda_t *) src1->data; + } else { + // Convert src1 to target type using traits conversion functions + const int64_t ne_src1 = ggml_nelements(src1); + src1_alloc.alloc(ne_src1); + + const auto convert_func = traits::get_nc_converter(src1->type); + GGML_ASSERT(convert_func != nullptr); + convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream); + src1_ptr = src1_alloc.get(); + s11 = ne10; + s12 = ne11*s11; + s13 = ne12*s12; + + is_src1_cont_2 = true; + } + + // Setup destination buffer + ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool()); + char * dst_t; + size_t nbd2 = dst->nb[2]; + size_t nbd3 = dst->nb[3]; + + cublasComputeType_t cu_compute_type = traits::compute_type; + cudaDataType_t cu_data_type = traits::data_type; + cudaDataType_t cu_data_type_a = traits::data_type; + cudaDataType_t cu_data_type_b = traits::data_type; + const void * alpha = traits::get_alpha(); + const void * beta = traits::get_beta(); + const float alpha_f32 = 1.0f; + const float beta_f32 = 0.0f; + + if (dst->op_params[0] == GGML_PREC_DEFAULT) { + if constexpr (src0_type == GGML_TYPE_F32) { + dst_t = (char *) dst_ddf; // Direct F32 output + } else { + dst_t = (char *) dst_temp.alloc(ne_dst); + nbd2 /= sizeof(float) / sizeof(cuda_t); + nbd3 /= sizeof(float) / sizeof(cuda_t); + } + } else { + dst_t = (char *) dst_ddf; + cu_compute_type = CUBLAS_COMPUTE_32F; + cu_data_type = CUDA_R_32F; + alpha = &alpha_f32; + beta = &beta_f32; + } + + int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { + cu_compute_type = CUBLAS_COMPUTE_32F; + alpha = &alpha_f32; + beta = &beta_f32; + } + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + // broadcast factors + const int64_t r2 = ne12/ne02; + const int64_t r3 = ne13/ne03; + + if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) { + // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3: + const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00; + const int64_t smb = ne12 == 1 ? s13 : s12; + + // there is no broadcast and src0, src1 are contiguous across dims 2, 3 + // use cublasGemmStridedBatchedEx + CUBLAS_CHECK( + cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA + src1_ptr, cu_data_type_b, s11, smb, // strideB + beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC + ne12*ne13, + cu_compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } else { + // use cublasGemmBatchedEx + const int64_t ne23 = ne12*ne13; + + ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23); + ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23); + + size_t src1_stride_size = sizeof(cuda_t); + + const int threads_x = 16; + const int threads_y = 16; + dim3 block_dims(threads_x, threads_y); + + dim3 grid_dims( + (ne13 + threads_x - 1) / threads_x, + (ne12 + threads_y - 1) / threads_y + ); + k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>( + src0_ptr, src1_ptr, dst_t, + ptrs_src.get(), ptrs_dst.get(), + ne12, ne13, + ne23, + nb02, nb03, + (src1->type == src0_type) ? nb12 : s12*src1_stride_size, + (src1->type == src0_type) ? nb13 : s13*src1_stride_size, + nbd2, nbd3, + r2, r3); + + CUDA_CHECK(cudaGetLastError()); + + CUBLAS_CHECK( + cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00, + (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11, + beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0, + ne23, + cu_compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } + + // Convert output back to F32 if needed + if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) { + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val); + to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream); + } +} + +static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32); + + switch (src0->type) { + case GGML_TYPE_F32: + ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst); + break; + case GGML_TYPE_BF16: + ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst); + break; + case GGML_TYPE_F16: + ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst); + break; + default: + GGML_ABORT("Unsupported type"); + } +} + +static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up, + const ggml_tensor * ffn_gate, + const ggml_tensor * glu, + const ggml_tensor * ffn_up_bias = nullptr, + const ggml_tensor * ffn_gate_bias = nullptr) { + const bool has_bias = ffn_up_bias != nullptr || ffn_gate_bias != nullptr; + + if (has_bias && (!ffn_up_bias || !ffn_gate_bias)) { + return false; + } + + const bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU; + const bool is_mul_mat_id = ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU; + + GGML_ASSERT(ffn_up && ffn_gate && glu); + + if (!is_mul_mat && !is_mul_mat_id) { + return false; + } + + const ggml_op expected_bias_op = is_mul_mat ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (has_bias) { + if (ffn_up_bias->op != expected_bias_op || ffn_gate_bias->op != expected_bias_op) { + return false; + } + + if (glu->src[0] != ffn_gate_bias || glu->src[1] != ffn_up_bias) { + return false; + } + + if (expected_bias_op == GGML_OP_ADD) { + const bool up_has_mul = ffn_up_bias->src[0] == ffn_up || ffn_up_bias->src[1] == ffn_up; + const bool gate_has_mul = ffn_gate_bias->src[0] == ffn_gate || ffn_gate_bias->src[1] == ffn_gate; + if (!up_has_mul || !gate_has_mul) { + return false; + } + } else { // GGML_OP_ADD_ID + if (ffn_up_bias->src[0] != ffn_up || ffn_gate_bias->src[0] != ffn_gate) { + return false; + } + if (ffn_up_bias->src[2] != ffn_up->src[2] || ffn_gate_bias->src[2] != ffn_gate->src[2]) { + return false; + } + } + } else { + if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) { + return false; + } + } + + if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) || + !ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) { + return false; + } + + if (ffn_up->src[1] != ffn_gate->src[1]) { + return false; + } + + if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) { + return false; + } + + static constexpr std::array<ggml_glu_op, 3> valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU_OAI }; + + if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) { + return false; + } + + if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) { + return false; + } + + const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) || + ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft); + + //TODO: add support for fusion for split buffers + if (split) { + return false; + } + + return true; +} + +static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) { + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + const ggml_tensor * dst = tensor; + + const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID; + + bool use_mul_mat_vec_f = + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) && + src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, is_mul_mat_id ? src1->ne[2] : src1->ne[1]); + + const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) || + ggml_backend_buft_is_cuda_split(src1->buffer->buft); + + //TODO: add support for fusion for split buffers + if (split) { + return false; + } + + //we only support fusion for ncols_dst = 1 + if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) { + return false; + } + + if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) { + return false; + } + + + return use_mul_mat_vec_f; +} + +static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) { + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + const ggml_tensor * dst = tensor; + + const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && + ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && + src0->view_src; + + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + + // fusion is not universally faster on Pascal + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + if (cc <= GGML_CUDA_CC_PASCAL) { + return false; + } + //we only support fusion for ncols_dst = 1 + if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) { + return false; + } + + if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) { + return false; + } + + + const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) || + ggml_backend_buft_is_cuda_split(src1->buffer->buft); + + //TODO: add support for fusion for split buffers + if (split) { + return false; + } + + return use_mul_mat_vec_q; +} + +static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft); + + // If src0 is a temporary compute buffer it may have some padding that needs to be cleared for mul_mat_vec_q or mul_mat_q. + // But if src0 is also a view of another tensor then this cannot be done safely because it may overwrite valid tensor data. + // Therefore, in such cases use cuBLAS. + const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE + && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src; + + bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + bool use_mul_mat_f = !ggml_is_quantized(src0->type) + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; + + bool any_gpus_with_slow_fp16 = false; + + if (split) { + ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context; + auto & tensor_split = buft_ctx->tensor_split; + for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + // skip devices that are not going to do any work: + if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) { + continue; + } + + const int cc = ggml_cuda_info().devices[id].cc; + const int warp_size = ggml_cuda_info().devices[id].warp_size; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0); + use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false); + use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); + } + } else { + const int cc = ggml_cuda_info().devices[ctx.device].cc; + const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0); + use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false); + use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); + } + + // debug helpers + //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); + //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]); + //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]); + //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]); + //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); + //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); + + //TODO update for generic tensor parallelism + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16); + bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc); + bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32; + + if (!split && use_mul_mat_vec_f) { + // the custom F16 vector kernel can be used over batched cuBLAS GEMM + // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) + ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst); + } else if (!split && use_mul_mat_f) { + ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst); + } else if (!split && use_mul_mat_vec_q) { + ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst); + } else if (!split && use_mul_mat_q) { + ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst); + } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32) + && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + // general KQ + KQV multi-batch without FlashAttention + ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); + } else if (use_mul_mat_vec_f) { + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr); + } else if (use_mul_mat_vec_q) { + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); + } else if (use_mul_mat_q) { + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda); + } else { + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr); + } +} + +static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * ids = dst->src[2]; + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers"); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE); + if (ne2 <= MMVQ_MAX_BATCH_SIZE) { + if (ggml_is_quantized(src0->type)) { + if (ne2 <= 4) { + ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); + return; + } + } else { + if (GGML_CUDA_CC_IS_AMD(cc)) { + ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst); + return; + } + } + } + + if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) { + ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst); + return; + } + + if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src0->nb, src1->ne[2], /*mul_mat_id=*/true)) { + ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst); + return; + } + } + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(nb12 % nb11 == 0); + GGML_ASSERT(nb2 % nb1 == 0); + + const ggml_type type_src1_sorted = (src0->type == GGML_TYPE_F16 && !fast_fp16_hardware_available(cc)) + || ggml_is_quantized(src0->type) ? GGML_TYPE_F32 : src0->type; + const ggml_type type_dst_sorted = GGML_TYPE_F32; + const size_t ts_src1_sorted = ggml_type_size(type_src1_sorted); + const size_t ts_dst_sorted = ggml_type_size(type_dst_sorted); + + const int64_t n_expert_used = ids->ne[0]; + const int64_t ne_get_rows = ne12 * n_expert_used; + + std::vector<int32_t> ids_to_sorted_host; + ids_to_sorted_host.reserve(2*ne_get_rows); + std::vector<int32_t> ids_from_sorted_host(ne_get_rows); + + ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool(), 2*ne_get_rows); + + std::vector<int32_t> tokens_per_expert(ne02); + + ggml_cuda_pool_alloc<char> src1_sorted(ctx.pool(), ne12*n_expert_used*ne10*ts_src1_sorted); + ggml_cuda_pool_alloc<char> dst_sorted(ctx.pool(), ne2 *n_expert_used* ne0*ts_dst_sorted); + + std::vector<char> ids_host(ggml_nbytes(ids)); + CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices + for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens + for (int64_t iex = 0; iex < n_expert_used; ++iex) { + const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]); + assert(expert_to_use >= 0 && expert_to_use < ne02); + if (expert_to_use == i02) { + ids_from_sorted_host[i12*n_expert_used + iex] = ids_to_sorted_host.size(); + ids_to_sorted_host.push_back(i12*ne11 + iex % ne11); + tokens_per_expert[i02]++; + break; + } + } + } + } + GGML_ASSERT(ids_to_sorted_host.size() == size_t(ne_get_rows)); + + ids_to_sorted_host.insert(ids_to_sorted_host.end(), ids_from_sorted_host.begin(), ids_from_sorted_host.end()); + + CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_to_sorted_host.data(), 2*ne_get_rows*sizeof(int32_t), cudaMemcpyHostToDevice, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); + + const int32_t * ids_to_sorted = ids_buf_dev.ptr + 0*ne_get_rows; + const int32_t * ids_from_sorted = ids_buf_dev.ptr + 1*ne_get_rows; + + get_rows_cuda(src1->data, src1->type, ids_to_sorted, src1_sorted.ptr, type_src1_sorted, + ne10, nb11, nb12, nb13, + ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t), + ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, stream); + CUDA_CHECK(cudaGetLastError()); + + char * src1_data_cur = (char *) src1_sorted.ptr; + char * dst_data_cur = (char *) dst_sorted.ptr; + for (int64_t i02 = 0; i02 < ne02; ++i02) { + if (tokens_per_expert[i02] == 0) { + continue; + } + + ggml_tensor src0_slice = *src0; + src0_slice.ne[2] = 1; + src0_slice.nb[3] = src0_slice.nb[2]; + src0_slice.op = GGML_OP_VIEW; + src0_slice.view_src = dst->src[0]; // non-const pointer to src0 + src0_slice.data = (char *) src0->data + i02*nb02; + + ggml_tensor src1_slice; + memset(&src1_slice, 0, sizeof(src1_slice)); + src1_slice.buffer = src1->buffer; + src1_slice.type = type_src1_sorted; + src1_slice.ne[0] = ne10; + src1_slice.ne[1] = tokens_per_expert[i02]; + src1_slice.ne[2] = 1; + src1_slice.ne[3] = 1; + src1_slice.nb[0] = ts_src1_sorted; + src1_slice.nb[1] = src1_slice.ne[0] * src1_slice.nb[0]; + src1_slice.nb[2] = src1_slice.ne[1] * src1_slice.nb[1]; + src1_slice.nb[3] = src1_slice.ne[2] * src1_slice.nb[2]; + src1_slice.data = src1_data_cur; + + ggml_tensor dst_slice; + memset(&dst_slice, 0, sizeof(dst_slice)); + dst_slice.buffer = dst->buffer; + dst_slice.type = type_dst_sorted; + dst_slice.ne[0] = ne0; + dst_slice.ne[1] = tokens_per_expert[i02]; + dst_slice.ne[2] = 1; + dst_slice.ne[3] = 1; + dst_slice.nb[0] = ts_dst_sorted; + dst_slice.nb[1] = dst_slice.ne[0] * dst_slice.nb[0]; + dst_slice.nb[2] = dst_slice.ne[1] * dst_slice.nb[1]; + dst_slice.nb[3] = dst_slice.ne[2] * dst_slice.nb[2]; + dst_slice.data = dst_data_cur; + + ggml_cuda_mul_mat(ctx, &src0_slice, &src1_slice, &dst_slice); + CUDA_CHECK(cudaGetLastError()); + + src1_data_cur += src1_slice.nb[2]; + dst_data_cur += dst_slice.nb[2]; + } + + get_rows_cuda(dst_sorted.ptr, type_dst_sorted, ids_from_sorted, dst->data, dst->type, + ne0, ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted, + ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t), + nb1, nb2, nb3, stream); +} + +static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) { + // why is this here instead of mul_mat? + if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) { + ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device); + } + + switch (dst->op) { + case GGML_OP_ARGMAX: + ggml_cuda_argmax(ctx, dst); + break; + case GGML_OP_COUNT_EQUAL: + ggml_cuda_count_equal(ctx, dst); + break; + case GGML_OP_REPEAT: + ggml_cuda_op_repeat(ctx, dst); + break; + case GGML_OP_REPEAT_BACK: + ggml_cuda_op_repeat_back(ctx, dst); + break; + case GGML_OP_GET_ROWS: + ggml_cuda_op_get_rows(ctx, dst); + break; + case GGML_OP_GET_ROWS_BACK: + ggml_cuda_op_get_rows_back(ctx, dst); + break; + case GGML_OP_SET_ROWS: + ggml_cuda_op_set_rows(ctx, dst); + break; + case GGML_OP_SET: + ggml_cuda_op_set(ctx, dst); + break; + case GGML_OP_DUP: + ggml_cuda_dup(ctx, dst); + break; + case GGML_OP_CPY: + ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]); + break; + case GGML_OP_CONT: + ggml_cuda_dup(ctx, dst); + break; + case GGML_OP_ADD: + case GGML_OP_ADD1: // TODO: more efficient implementation + ggml_cuda_op_add(ctx, dst); + break; + case GGML_OP_ADD_ID: + ggml_cuda_op_add_id(ctx, dst); + break; + case GGML_OP_SUB: + ggml_cuda_op_sub(ctx, dst); + break; + case GGML_OP_ACC: + ggml_cuda_op_acc(ctx, dst); + break; + case GGML_OP_MUL: + ggml_cuda_op_mul(ctx, dst); + break; + case GGML_OP_DIV: + ggml_cuda_op_div(ctx, dst); + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(dst)) { + case GGML_UNARY_OP_ABS: + ggml_cuda_op_abs(ctx, dst); + break; + case GGML_UNARY_OP_SGN: + ggml_cuda_op_sgn(ctx, dst); + break; + case GGML_UNARY_OP_NEG: + ggml_cuda_op_neg(ctx, dst); + break; + case GGML_UNARY_OP_STEP: + ggml_cuda_op_step(ctx, dst); + break; + case GGML_UNARY_OP_GELU: + ggml_cuda_op_gelu(ctx, dst); + break; + case GGML_UNARY_OP_SILU: + ggml_cuda_op_silu(ctx, dst); + break; + case GGML_UNARY_OP_GELU_ERF: + ggml_cuda_op_gelu_erf(ctx, dst); + break; + case GGML_UNARY_OP_GELU_QUICK: + ggml_cuda_op_gelu_quick(ctx, dst); + break; + case GGML_UNARY_OP_TANH: + ggml_cuda_op_tanh(ctx, dst); + break; + case GGML_UNARY_OP_RELU: + ggml_cuda_op_relu(ctx, dst); + break; + case GGML_UNARY_OP_SIGMOID: + ggml_cuda_op_sigmoid(ctx, dst); + break; + case GGML_UNARY_OP_HARDSIGMOID: + ggml_cuda_op_hardsigmoid(ctx, dst); + break; + case GGML_UNARY_OP_HARDSWISH: + ggml_cuda_op_hardswish(ctx, dst); + break; + case GGML_UNARY_OP_EXP: + ggml_cuda_op_exp(ctx, dst); + break; + case GGML_UNARY_OP_ELU: + ggml_cuda_op_elu(ctx, dst); + break; + case GGML_UNARY_OP_XIELU: + ggml_cuda_op_xielu(ctx, dst); + break; + case GGML_UNARY_OP_FLOOR: + ggml_cuda_op_floor(ctx, dst); + break; + case GGML_UNARY_OP_CEIL: + ggml_cuda_op_ceil(ctx, dst); + break; + case GGML_UNARY_OP_ROUND: + ggml_cuda_op_round(ctx, dst); + break; + case GGML_UNARY_OP_TRUNC: + ggml_cuda_op_trunc(ctx, dst); + break; + case GGML_UNARY_OP_EXPM1: + ggml_cuda_op_expm1(ctx, dst); + break; + case GGML_UNARY_OP_SOFTPLUS: + ggml_cuda_op_softplus(ctx, dst); + break; + default: + return false; + } + break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(dst)) { + case GGML_GLU_OP_REGLU: + ggml_cuda_op_reglu(ctx, dst); + break; + case GGML_GLU_OP_GEGLU: + ggml_cuda_op_geglu(ctx, dst); + break; + case GGML_GLU_OP_SWIGLU: + ggml_cuda_op_swiglu(ctx, dst); + break; + case GGML_GLU_OP_SWIGLU_OAI: + ggml_cuda_op_swiglu_oai(ctx, dst); + break; + case GGML_GLU_OP_GEGLU_ERF: + ggml_cuda_op_geglu_erf(ctx, dst); + break; + case GGML_GLU_OP_GEGLU_QUICK: + ggml_cuda_op_geglu_quick(ctx, dst); + break; + default: + return false; + } + break; + case GGML_OP_NORM: + ggml_cuda_op_norm(ctx, dst); + break; + case GGML_OP_GROUP_NORM: + ggml_cuda_op_group_norm(ctx, dst); + break; + case GGML_OP_L2_NORM: + ggml_cuda_op_l2_norm(ctx, dst); + break; + case GGML_OP_CONCAT: + ggml_cuda_op_concat(ctx, dst); + break; + case GGML_OP_UPSCALE: + ggml_cuda_op_upscale(ctx, dst); + break; + case GGML_OP_PAD: + ggml_cuda_op_pad(ctx, dst); + break; + case GGML_OP_PAD_REFLECT_1D: + ggml_cuda_op_pad_reflect_1d(ctx, dst); + break; + case GGML_OP_ARANGE: + ggml_cuda_op_arange(ctx, dst); + break; + case GGML_OP_TIMESTEP_EMBEDDING: + ggml_cuda_op_timestep_embedding(ctx, dst); + break; + case GGML_OP_LEAKY_RELU: + ggml_cuda_op_leaky_relu(ctx, dst); + break; + case GGML_OP_SILU_BACK: + ggml_cuda_op_silu_back(ctx, dst); + break; + case GGML_OP_RMS_NORM: + ggml_cuda_op_rms_norm(ctx, dst); + break; + case GGML_OP_RMS_NORM_BACK: + ggml_cuda_op_rms_norm_back(ctx, dst); + break; + case GGML_OP_MUL_MAT: + ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst); + break; + case GGML_OP_MUL_MAT_ID: + ggml_cuda_mul_mat_id(ctx, dst); + break; + case GGML_OP_OUT_PROD: + ggml_cuda_out_prod(ctx, dst); + break; + case GGML_OP_SCALE: + ggml_cuda_op_scale(ctx, dst); + break; + case GGML_OP_SQR: + ggml_cuda_op_sqr(ctx, dst); + break; + case GGML_OP_SQRT: + ggml_cuda_op_sqrt(ctx, dst); + break; + case GGML_OP_SIN: + ggml_cuda_op_sin(ctx, dst); + break; + case GGML_OP_COS: + ggml_cuda_op_cos(ctx, dst); + break; + case GGML_OP_CLAMP: + ggml_cuda_op_clamp(ctx, dst); + break; + case GGML_OP_LOG: + ggml_cuda_op_log(ctx, dst); + break; + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + break; + case GGML_OP_DIAG: + ggml_cuda_op_diag(ctx, dst); + break; + case GGML_OP_DIAG_MASK_INF: + ggml_cuda_op_diag_mask_inf(ctx, dst); + break; + case GGML_OP_SOFT_MAX: + ggml_cuda_op_soft_max(ctx, dst); + break; + case GGML_OP_SOFT_MAX_BACK: + ggml_cuda_op_soft_max_back(ctx, dst); + break; + case GGML_OP_ROPE: + ggml_cuda_op_rope(ctx, dst); + break; + case GGML_OP_ROPE_BACK: + ggml_cuda_op_rope_back(ctx, dst); + break; + case GGML_OP_ROLL: + ggml_cuda_op_roll(ctx, dst); + break; + case GGML_OP_IM2COL: + ggml_cuda_op_im2col(ctx, dst); + break; + case GGML_OP_IM2COL_3D: + ggml_cuda_op_im2col_3d(ctx, dst); + break; + case GGML_OP_CONV_2D: + ggml_cuda_op_conv2d(ctx, dst); + break; + case GGML_OP_CONV_2D_DW: + ggml_cuda_op_conv2d_dw(ctx, dst); + break; + case GGML_OP_CONV_TRANSPOSE_2D: + ggml_cuda_conv_2d_transpose_p0(ctx, dst); + break; + case GGML_OP_CONV_TRANSPOSE_1D: + ggml_cuda_op_conv_transpose_1d(ctx,dst); + break; + case GGML_OP_POOL_2D: + ggml_cuda_op_pool2d(ctx, dst); + break; + case GGML_OP_SUM: + ggml_cuda_op_sum(ctx, dst); + break; + case GGML_OP_CUMSUM: + ggml_cuda_op_cumsum(ctx, dst); + break; + case GGML_OP_SUM_ROWS: + ggml_cuda_op_sum_rows(ctx, dst); + break; + case GGML_OP_MEAN: + ggml_cuda_op_mean(ctx, dst); + break; + case GGML_OP_SSM_CONV: + ggml_cuda_op_ssm_conv(ctx, dst); + break; + case GGML_OP_SSM_SCAN: + ggml_cuda_op_ssm_scan(ctx, dst); + break; + case GGML_OP_TOP_K: + ggml_cuda_op_top_k(ctx, dst); + break; + case GGML_OP_ARGSORT: + ggml_cuda_op_argsort(ctx, dst); + break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_cuda_flash_attn_ext(ctx, dst); + break; + case GGML_OP_CROSS_ENTROPY_LOSS: + ggml_cuda_cross_entropy_loss(ctx, dst); + break; + case GGML_OP_TRI: + ggml_cuda_op_tri(ctx, dst); + break; + case GGML_OP_RWKV_WKV6: + ggml_cuda_op_rwkv_wkv6(ctx, dst); + break; + case GGML_OP_GATED_LINEAR_ATTN: + ggml_cuda_op_gated_linear_attn(ctx, dst); + break; + case GGML_OP_RWKV_WKV7: + ggml_cuda_op_rwkv_wkv7(ctx, dst); + break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + ggml_cuda_cross_entropy_loss_back(ctx, dst); + break; + case GGML_OP_OPT_STEP_ADAMW: + ggml_cuda_opt_step_adamw(ctx, dst); + break; + case GGML_OP_OPT_STEP_SGD: + ggml_cuda_opt_step_sgd(ctx, dst); + break; + case GGML_OP_SOLVE_TRI: + ggml_cuda_op_solve_tri(ctx, dst); + break; + case GGML_OP_FILL: + ggml_cuda_op_fill(ctx, dst); + break; + default: + return false; + } + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + GGML_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst)); + CUDA_CHECK(err); + } + + return true; +} + +//////////////////////////////////////////////////////////////////////////////// + +// backend + +static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + + return cuda_ctx->name.c_str(); +} + +static void ggml_backend_cuda_free(ggml_backend_t backend) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + + delete cuda_ctx; + delete backend; +} + +static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); + + CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream())); +} + +static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); + + CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream())); +} + +static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { + ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer; + ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer; + + if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) { + return false; + } + + if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) { + return false; + } + + // device -> device copy + ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context; + ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context; + + ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; + ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; + + if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) { +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__); +#endif + return false; + } + + if (backend_src != backend_dst) { + // copy on src stream + if (cuda_ctx_src->device == cuda_ctx_dst->device) { + CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream())); + } else { +#ifdef GGML_CUDA_NO_PEER_COPY + return false; +#else + CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream())); +#endif + } + + // record event on src stream after the copy + if (!cuda_ctx_src->copy_event) { + ggml_cuda_set_device(cuda_ctx_src->device); + CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming)); + } + + CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream())); + + // wait on dst stream for the copy to complete + CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), cuda_ctx_src->copy_event, 0)); + } else { + // src and dst are on the same backend + CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream())); + } + return true; +} + +static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + + CUDA_CHECK(cudaStreamSynchronize(cuda_ctx->stream())); + + GGML_UNUSED(backend); +} + +#ifdef USE_CUDA_GRAPH +static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { + + bool use_cuda_graph = true; + // Loop over nodes in GGML graph to obtain info needed for CUDA graph + + const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; + const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj"; + const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased"; + const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased"; + const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased"; + const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out"; + const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d"; + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + + if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) { + use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__); +#endif + } + + if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) { + use_cuda_graph = false; // This node type is not supported by CUDA graph capture +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); +#endif + } + + if (node->op == GGML_OP_ADD && + node->src[1] && node->src[1]->ne[1] > 1 && + (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && + (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) && + strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 && + strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 && + strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 && + strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 && + strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) { + // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation + // by means of matching node names. See + // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and + // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773, + // Generally, changes in batch size or context size can cause changes to the grid size of some kernels. + use_cuda_graph = false; +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); +#endif + } + + if (!use_cuda_graph) { + break; + } + } + + return use_cuda_graph; +} + +static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { + memset(props, 0, sizeof(ggml_cuda_graph_node_properties)); + props->node_data = node->data; + props->node_op = node->op; + props->node_type = node->type; + props->flags = node->flags; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + props->ne[i] = node->ne[i]; + props->nb[i] = node->nb[i]; + } + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (!node->src[i]) { + continue; + } + + props->src_data[i] = node->src[i]->data; + } + memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS); +} + +static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) { + if (node->data != props->node_data && node->op != GGML_OP_VIEW) { + return false; + } + + if (node->op != props->node_op) { + return false; + } + + if (node->type != props->node_type) { + return false; + } + + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (node->ne[i] != props->ne[i]) { + return false; + } + if (node->nb[i] != props->nb[i]) { + return false; + } + } + + if (node->op != GGML_OP_VIEW) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (!node->src[i]) { + if (props->src_data[i] != nullptr) { + return false; + } + continue; + } + + if (node->src[i]->data != props->src_data[i]) { + return false; + } + } + } + + if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { + return false; + } + + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) { + return false; + } + + return true; +} + +static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) { + return cgraph->nodes[0]; +} + +static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { + bool res = false; + + const void * graph_key = ggml_cuda_graph_get_key(cgraph); + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + + if (graph->instance == nullptr) { + res = true; + } + + // Check if the graph size has changed + if (graph->props.size() != (size_t)cgraph->n_nodes) { + res = true; + graph->props.resize(cgraph->n_nodes); + } + + // Loop over nodes in GGML graph to determine if CUDA graph update is required + // and store properties to allow this comparison for the next token + std::unordered_set<ggml_tensor *> seen_node; + std::vector<ggml_tensor *> srcs_extra; + for (int i = 0; i < cgraph->n_nodes; i++) { + bool props_match = true; + + seen_node.insert(cgraph->nodes[i]); + + if (!res) { + props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]); + } + if (!props_match) { + res = true; + } + ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]); + + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + ggml_tensor * src = cgraph->nodes[i]->src[src_idx]; + if (src && seen_node.find(src) == seen_node.end()) { + srcs_extra.push_back(src); + } + } + } + + if (graph->extra.size() != (size_t) srcs_extra.size()) { + res = true; + graph->extra.resize(srcs_extra.size()); + } + + for (size_t i = 0; i < srcs_extra.size(); ++i) { + bool props_match = true; + + if (!res) { + props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]); + } + + if (!props_match) { + res = true; + } + ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]); + } + + return res; +} + +static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) { + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + +#if CUDART_VERSION >= 12000 + cudaGraphExecUpdateResultInfo result_info; + cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info); +#else + cudaGraphNode_t errorNode; + cudaGraphExecUpdateResult result_info; + cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info); +#endif // CUDART_VERSION >= 12000 + + if (stat == cudaErrorGraphExecUpdateFailure) { +#ifndef NDEBUG + GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__); +#endif + + // The pre-existing graph exec cannot be updated due to violated constraints + // so instead clear error and re-instantiate + (void)cudaGetLastError(); + CUDA_CHECK(cudaGraphExecDestroy(graph->instance)); + graph->instance = nullptr; + CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0)); + } else { + GGML_ASSERT(stat == cudaSuccess); + } +} +#endif // USE_CUDA_GRAPH + +static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope, + const ggml_tensor * view, + const ggml_tensor * set_rows) { + + if (rope->op != GGML_OP_ROPE || view->op != GGML_OP_VIEW || set_rows->op != GGML_OP_SET_ROWS) { + return false; + } + // ne3 not tested + if (rope->src[0]->ne[3] != 1) { + return false; + } + + if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) { + return false; + } + + if (set_rows->src[1]->type != GGML_TYPE_I64) { + return false; + } + + // The view should flatten two dims of rope into one dim + if (!ggml_is_contiguous(view) || view->ne[0] != rope->ne[0] * rope->ne[1]) { + return false; + } + + // Only norm/neox shaders have the fusion code + const int mode = ((const int32_t *) rope->op_params)[2]; + if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) { + return false; + } + + return true; +} + +static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) { + args.sigmoid = false; + args.softmax = false; + args.delayed_softmax = false; + args.prob_bias = false; + args.norm = false; + + const int n_nodes = cgraph->n_nodes; + ggml_tensor ** nodes = cgraph->nodes; + + if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) { + args.softmax = true; + } + + if (nodes[node_idx]->op == GGML_OP_UNARY) { + if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) { + return false; + } + args.sigmoid = true; + } + + if (nodes[node_idx]->op == GGML_OP_ARGSORT) { + args.delayed_softmax = true; + } + + node_idx++; + + if (args.sigmoid || args.softmax) { + // SOFTMAX -> RESHAPE + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE || + nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } + ggml_tensor * probs_reshaped = nodes[node_idx]; + node_idx++; + + if (node_idx >= n_nodes) { + return false; + } + + // src of bias add is the unreshaped probs (-2 instead of -1) + if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) { + args.prob_bias = true; + node_idx++; + } + // RESHAPE/ADD -> ARGSORT + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) { + return false; + } + + if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) { + return false; + } + + node_idx++; + + // ARGSORT-> VIEW + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW || + nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } + node_idx++; + + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) { + return false; + } + + // GET_ROWS + if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) { + return false; + } + node_idx++; + } else if (args.delayed_softmax) { + if (node_idx - 2 < 0) { + return false; + } + ggml_tensor * probs_reshaped = nodes[node_idx - 2]; + + // VIEW->ARGSORT + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW || + nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } + node_idx++; + + // GET_ROWS + if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] || + nodes[node_idx]->src[0] != probs_reshaped) { + return false; + } + node_idx++; + + static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; + + for (const ggml_op op : remaining_ops) { + if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } + node_idx++; + } + } + + // At this point we can check for norm + scale. Everything is now at least valid till the norm + if (node_idx >= n_nodes) { + return true; + } + + if (nodes[node_idx]->op == GGML_OP_RESHAPE) { + //check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE + static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP }; + + args.norm = true; + for (const ggml_op op : norm_ops) { + if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) { + node_idx++; + } else { + args.norm = false; + return true; + } + } + + // DIV <- CLAMP, RESHAPE + if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] || + nodes[node_idx]->src[0] != nodes[node_idx - 3]) { + args.norm = false; + return true; + } + node_idx++; + + if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + args.norm = false; + return true; + } + + node_idx++; + } + + if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) { + args.scale = true; + } + + return true; +} + +static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, + int node_idx, + std::initializer_list<enum ggml_op> ops, + std::initializer_list<enum ggml_unary_op> unary_ops) { +#ifndef NDEBUG + const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY); + GGML_ASSERT(unary_ops.size() == num_unary); +#endif + + const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1, + const std::initializer_list<enum ggml_op> & list2) { + return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end()); + }; + + std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU }; + std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU }; + + std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU }; + std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU }; + + if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) && + ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) { + const ggml_tensor * ffn_gate = cgraph->nodes[node_idx]; + const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1]; + const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2]; + const ggml_tensor * ffn_up_bias = cgraph->nodes[node_idx + 3]; + const ggml_tensor * glu = cgraph->nodes[node_idx + 4]; + + if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) { + return true; + } + } + + if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) && + ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) { + const ggml_tensor * ffn_gate = cgraph->nodes[node_idx]; + const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1]; + const ggml_tensor * glu = cgraph->nodes[node_idx + 2]; + + if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) { + return true; + } + } + + std::initializer_list<enum ggml_op> rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }; + + if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) { + const ggml_tensor * rope = cgraph->nodes[node_idx]; + const ggml_tensor * view = cgraph->nodes[node_idx + 1]; + const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2]; + + if (ggml_cuda_should_fuse_rope_set_rows(rope, view, set_rows)) { + return true; + } + } + + if (!ggml_can_fuse(cgraph, node_idx, ops)) { + return false; + } + + if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + const ggml_tensor *add = nullptr; + + if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) { + add = cgraph->nodes[node_idx+2]; + } + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + + //rms norm only supports F32 + if (mul->src[0]->type != GGML_TYPE_F32 || + mul->src[1]->type != GGML_TYPE_F32 || + mul->type != GGML_TYPE_F32) { + return false; + } + + if (add && (add->src[0]->type != GGML_TYPE_F32 || + add->src[1]->type != GGML_TYPE_F32 || + add->type != GGML_TYPE_F32) ) { + return false; + } + + //if rms norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) { + return false; + } + + //rms_norm kernel assumes contigous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + + if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) { + return false; + } + + return true; + } + + if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE + && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) { + const ggml_tensor *scale = cgraph->nodes[node_idx]; + const ggml_tensor *tanh = cgraph->nodes[node_idx+1]; + const ggml_tensor *scale2 = cgraph->nodes[node_idx+2]; + + GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(scale->type == GGML_TYPE_F32); + + if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) { + return false; + } + + // Check for bias + if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) { + return false; + } + + return true; + } + + return false; +} + +static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) { + bool graph_evaluated_or_captured = false; + + // flag used to determine whether it is an integrated_gpu + const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; + + ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context(); + bool is_concurrent_event_active = false; + ggml_cuda_concurrent_event * concurrent_event = nullptr; + bool should_launch_concurrent_events = false; + + const auto try_launch_concurrent_event = [&](const ggml_tensor * node) { + if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) { + concurrent_event = &stream_ctx.concurrent_events[node]; + + is_concurrent_event_active = true; + + GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name); + + cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0 + GGML_ASSERT(cuda_ctx->curr_stream_no == 0); + CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream)); + + for (int i = 1; i <= concurrent_event->n_streams; ++i) { + cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i); + CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event)); + } + } + }; + + while (!graph_evaluated_or_captured) { + // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. + // With the use of CUDA graphs, the execution will be performed by the graph launch. + if (!use_cuda_graph || cuda_graph_update_required) { + [[maybe_unused]] int prev_i = 0; + + if (stream_ctx.concurrent_events.size() > 0) { + should_launch_concurrent_events = true; + for (const auto & [tensor, event] : stream_ctx.concurrent_events) { + should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid(); + } + } + + if (should_launch_concurrent_events) { + // Restore original node order within each concurrent region to enable fusion within streams + + std::unordered_map<const ggml_tensor *, int> node_to_idx; + node_to_idx.reserve(cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; ++i) { + node_to_idx[cgraph->nodes[i]] = i; + } + + for (auto & [fork_node, event] : stream_ctx.concurrent_events) { + // Find positions of all nodes from this event in the current graph + std::vector<int> positions; + positions.reserve(event.original_order.size()); + + bool all_found = true; + for (const ggml_tensor * orig_node : event.original_order) { + auto it = node_to_idx.find(orig_node); + if (it != node_to_idx.end()) { + positions.push_back(it->second); + } else { + all_found = false; + break; + } + } + + if (!all_found || positions.size() != event.original_order.size()) { + continue; + } + + // Sort positions to get contiguous range + std::vector<int> sorted_positions = positions; + std::sort(sorted_positions.begin(), sorted_positions.end()); + + bool is_contiguous = true; + for (size_t i = 1; i < sorted_positions.size(); ++i) { + if (sorted_positions[i] != sorted_positions[i-1] + 1) { + is_contiguous = false; + break; + } + } + + if (!is_contiguous) { + continue; + } + + // Restore original order at the sorted positions + int start_pos = sorted_positions[0]; + for (size_t i = 0; i < event.original_order.size(); ++i) { + cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]); + } + } + } else { + stream_ctx.concurrent_events.clear(); + } + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (is_concurrent_event_active) { + GGML_ASSERT(concurrent_event); + + if (node == concurrent_event->join_node) { + cuda_ctx->curr_stream_no = 0; + for (int i = 1; i <= concurrent_event->n_streams; ++i) { + // Wait on join events of forked streams in the main stream + CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1], + cuda_ctx->stream(cuda_ctx->device, i))); + CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1])); + } + + is_concurrent_event_active = false; + concurrent_event = nullptr; + } else { + GGML_ASSERT (concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end()); + cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node]; + GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name); + } + } else if (i - prev_i > 1) { + //the previous node was fused + const ggml_tensor * prev_node = cgraph->nodes[i - 1]; + try_launch_concurrent_event(prev_node); + + if (is_concurrent_event_active) { + cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node]; + GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name); + } + } + +#ifdef GGML_CUDA_DEBUG + const int nodes_fused = i - prev_i - 1; + if (nodes_fused > 0) { + GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused); + } +#endif + prev_i = i; + + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + + // start of fusion operations + static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); + if (!disable_fusion) { + ggml_cuda_topk_moe_args args; + + if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX || + cgraph->nodes[i]->op == GGML_OP_ARGSORT) { + const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args); + + std::vector<ggml_op> ops; + + if (can_fuse) { + const ggml_tensor * logits = node->src[0]; + ggml_tensor * weights = nullptr; + ggml_tensor * ids = nullptr; + const ggml_tensor * bias = nullptr; + const ggml_tensor * clamp = nullptr; + const ggml_tensor * scale = nullptr; + + if (!args.delayed_softmax) { + ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX; + int out_nodes[2]; // nodes which can't be elided + + if (args.prob_bias) { + bias = cgraph->nodes[i + 2]->src[1]; + ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT, + GGML_OP_VIEW, GGML_OP_GET_ROWS }); + out_nodes[0] = i + 4; + ids = cgraph->nodes[i + 4]; + } else { + ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, + GGML_OP_GET_ROWS }); + out_nodes[0] = i + 3; + ids = cgraph->nodes[i + 3]; + } + + if (args.norm) { + ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, + GGML_OP_DIV, GGML_OP_RESHAPE }); + clamp = cgraph->nodes[i + ops.size() - 3]; + } + if (args.scale) { + ops.insert(ops.end(), { GGML_OP_SCALE }); + scale = cgraph->nodes[i + ops.size() - 1]; + } + + weights = cgraph->nodes[i + ops.size() - 1]; + out_nodes[1] = i + ops.size() - 1; + + if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && + ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) { + ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); + i += ops.size() - 1; + continue; + } + } else if (!args.norm && !args.prob_bias) { + //special case gpt-oss, no norm, no bias. + ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, + GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }); + weights = cgraph->nodes[i + 5]; + ids = cgraph->nodes[i + 1]; + const ggml_tensor * softmax = cgraph->nodes[i + 4]; + + int out_nodes[2] = { i + 1, i + 5 }; + if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && + ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) { + ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); + i += ops.size() - 1; + continue; + } + } + } + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) { + ggml_tensor * rope = cgraph->nodes[i]; + ggml_tensor * set_rows = cgraph->nodes[i + 2]; + + ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows); + i += 2; + continue; + } + + if (node->op == GGML_OP_ADD) { + int n_fuse = 0; + ggml_op ops[8]; + std::fill(ops, ops + 8, GGML_OP_ADD); + + for (; n_fuse <= 6; ++n_fuse){ + if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { + break; + } + if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) { + break; + } + if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) { + break; + } + } + + n_fuse++; + + if (n_fuse > 1) { + for (int j = 0; j < n_fuse - 1; ++j) { + node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + } + cgraph->nodes[i + n_fuse - 1]->data = node->data; + ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse); + i += n_fuse - 1; + + continue; + } + } + + bool fused_mul_mat_vec = false; + int fused_node_count = 0; + + for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { + const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) { + ggml_tensor * glu = cgraph->nodes[i + 4]; + ggml_tensor * gate_bias_n = glu->src[0]; + ggml_tensor * up_bias_n = glu->src[1]; + + //we don't assume the order for {gate, up}. Instead infer it from the bias tensor + ggml_tensor * gate_n = nullptr; + ggml_tensor * up_n = nullptr; + + if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) { + gate_n = cgraph->nodes[i]; + up_n = cgraph->nodes[i + 2]; + } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) { + gate_n = cgraph->nodes[i + 2]; + up_n = cgraph->nodes[i]; + } else { + continue; + } + + auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) { + if (op_bias == GGML_OP_ADD) { + if (bias_node->src[0] == mul_node) { + return bias_node->src[1]; + } + if (bias_node->src[1] == mul_node) { + return bias_node->src[0]; + } + return (ggml_tensor *) nullptr; + } + GGML_ASSERT(op_bias == GGML_OP_ADD_ID); + GGML_ASSERT(bias_node->src[0] == mul_node); + return bias_node->src[1]; + }; + + ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op); + ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op); + + if (!up_bias_tensor || !gate_bias_tensor) { + continue; + } + + // we don't support repeating adds + if (bias_op == GGML_OP_ADD && + (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) || + !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) { + continue; + } + + const ggml_tensor * src0 = up_n->src[0]; + const ggml_tensor * src1 = up_n->src[1]; + const ggml_tensor * ids = up_n->src[2]; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate_n->src[0]; + fusion_data.x_bias = up_bias_tensor; + fusion_data.gate_bias = gate_bias_tensor; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 5; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate_n->src[0]; + fusion_data.x_bias = up_bias_tensor; + fusion_data.gate_bias = gate_bias_tensor; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 5; + break; + } + } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) { + ggml_tensor * glu = cgraph->nodes[i + 2]; + ggml_tensor * gate = glu->src[0]; + ggml_tensor * up = glu->src[1]; + + bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1]) + || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]); + + if (!ok) continue; + + const ggml_tensor * src0 = up->src[0]; + const ggml_tensor * src1 = up->src[1]; + const ggml_tensor * ids = up->src[2]; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate->src[0]; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 3; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate->src[0]; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 3; + break; + } + } + } + + if (fused_mul_mat_vec) { + i += fused_node_count - 1; + continue; + } + + fused_mul_mat_vec = false; + fused_node_count = 0; + + for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { + const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (!ggml_can_fuse(cgraph, i, { op, bias_op })) { + continue; + } + + ggml_tensor * mm_node = cgraph->nodes[i]; + ggml_tensor * bias_node = cgraph->nodes[i + 1]; + + ggml_tensor * bias_tensor = nullptr; + if (bias_op == GGML_OP_ADD) { + if (bias_node->src[0] == mm_node) { + bias_tensor = bias_node->src[1]; + } else if (bias_node->src[1] == mm_node) { + bias_tensor = bias_node->src[0]; + } else { + continue; + } + } else { + if (bias_node->src[0] != mm_node) { + continue; + } + bias_tensor = bias_node->src[1]; + } + + const ggml_tensor * src0 = mm_node->src[0]; + const ggml_tensor * src1 = mm_node->src[1]; + const ggml_tensor * ids = mm_node->src[2]; + + if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) { + continue; + } + + if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) { + continue; + } + + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.x_bias = bias_tensor; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) { + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 2; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) { + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 2; + break; + } + } + + if (fused_mul_mat_vec) { + i += fused_node_count - 1; + continue; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) { + ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); + i += 2; + continue; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) { + ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); + i++; + continue; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { + i += 2; + ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); + continue; + } + } +#ifndef NDEBUG + assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j] != nullptr) { + assert(node->src[j]->buffer); + assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) || + ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft))); + } + } +#else + GGML_UNUSED(integrated); +#endif // NDEBUG + + bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); + if (!ok) { + GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); + + if (!is_concurrent_event_active) { + try_launch_concurrent_event(node); + } + } + } + +#ifdef USE_CUDA_GRAPH + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture + if (graph->graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(graph->graph)); + graph->graph = nullptr; + } + + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph)); + graph_evaluated_or_captured = true; // CUDA graph has been captured + + std::lock_guard<std::mutex> lock(ggml_cuda_lock); + if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) { + ggml_cuda_lock_cv.notify_all(); + } + } else { + graph_evaluated_or_captured = true; // ggml graph has been directly evaluated + } + } + + if (use_cuda_graph) { + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + if (graph->instance == nullptr) { // Create executable graph from captured graph. + CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0)); + } + if (cuda_graph_update_required) { // Update graph executable + ggml_cuda_graph_update_executable(cuda_ctx, graph_key); + } + // Launch graph + CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream())); +#else + GGML_UNUSED(graph_key); + graph_evaluated_or_captured = true; +#endif // USE_CUDA_GRAPH + } +} + +#ifdef USE_CUDA_GRAPH +static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) { + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + + if (graph->graph == nullptr) { + if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { + if (!graph->disable_due_to_gpu_arch) { + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); + } + graph->disable_due_to_gpu_arch = true; + } + } + + return graph->is_enabled(); +} +#endif // USE_CUDA_GRAPH + +static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + + ggml_cuda_set_device(cuda_ctx->device); + + bool use_cuda_graph = false; + bool cuda_graph_update_required = false; + const void * graph_key = nullptr; + +#ifdef USE_CUDA_GRAPH + graph_key = ggml_cuda_graph_get_key(cgraph); + + use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); + + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + if (graph->is_enabled()) { + cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph); + use_cuda_graph = ggml_cuda_graph_check_compability(cgraph); + + graph->record_update(use_cuda_graph, cuda_graph_update_required); + } +#endif // USE_CUDA_GRAPH + + if (use_cuda_graph && cuda_graph_update_required) { + // Start CUDA graph capture + { + std::lock_guard<std::mutex> lock(ggml_cuda_lock); + ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed); + } + + CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); + } + + ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key); + + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + + CUDA_CHECK(cudaEventRecord((cudaEvent_t)event->context, cuda_ctx->stream())); +} + +static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + + if (ggml_backend_is_cuda(backend)) { + CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), (cudaEvent_t)event->context, 0)); + } else { +#if 0 + // untested + auto wait_fn = [](void * user_data) { + ggml_backend_event_t event = (ggml_backend_event_t)user_data; + ggml_backend_event_synchronize(event); + }; + + CUDA_CHECK(cudaLaunchHostFunc(cuda_ctx->stream(), wait_fn, event)); +#endif + GGML_ABORT("fatal error"); + } +} + +static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + +#ifdef USE_CUDA_GRAPH + const void * graph_key = ggml_cuda_graph_get_key(cgraph); + const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); +#else + const bool use_cuda_graph = false; + GGML_UNUSED(cuda_ctx); + GGML_UNUSED(cgraph); +#endif + + static bool enable_graph_optimization = [] { + const char * env = getenv("GGML_CUDA_GRAPH_OPT"); + return env != nullptr && atoi(env) == 1; + }(); + + if (!enable_graph_optimization) { + return; + } + + ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context(); + stream_context.reset(); + + if (!use_cuda_graph || ggml_backend_cuda_get_device_count() != 1) { + return; + } + + // number of out-degrees for a particular node + std::unordered_map<const ggml_tensor *, int> fan_out; + // reverse mapping of node to index in the cgraph + std::unordered_map<const ggml_tensor *, int> node_indices; + + const auto & is_noop = [](const ggml_tensor * node) -> bool { + return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || + node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; + }; + + const auto & depends_on = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool { + for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { + if (dst->src[s] == src) { + return true; + } + } + // implicit dependency if they view the same tensor + const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst; + const ggml_tensor * src2 = src->view_src ? src->view_src : src; + if (dst2 == src2) { + return true; + } + return false; + }; + + for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) { + const ggml_tensor * node = cgraph->nodes[node_idx]; + node_indices[node] = node_idx; + + if (is_noop(node)) { + continue; + } + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + const ggml_tensor * src = cgraph->nodes[node_idx]->src[src_idx]; + //TODO: check why nrows > 1 fails + if (node && !is_noop(node) && ggml_nrows(node) <= 1) { + fan_out[src] += 1; + } + } + } + + // Target Q, K, V for concurrency + // this is a more general way to find nodes which can be candidates for concurrency (although it has not been tested for anything else): + // 1. find fan-out (fork) nodes where the same input is used at least N times (in QKV, it would be "attn-norm") + // 2. find the join node, where 2 or more of the outputs are required (in QKV, this would "KQ" or "flash-attn") + // 3. account for all branches from the fork to the join + // 4. To extend lifetimes of the tensors, we interleave the branches (see below for more details) + // 5. save the original cgraph and restore it in graph_compute, to enable fusion within streams + // See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030 + + const int min_fan_out = 3; + const int max_fan_out = 3; + + // store {fork_idx, join_idx} + std::vector<std::pair<int, int>> concurrent_node_ranges; + + for (const auto & [root_node, count] : fan_out) { + if (count >= min_fan_out && count <= max_fan_out) { + const int root_node_idx = node_indices[root_node]; + + // only optimize for attn_norm + // TODO: make this more generic + if (!strstr(root_node->name, "attn_norm")) { + continue; + } + + bool is_part_of_event = false; + for (const auto & [start, end] : concurrent_node_ranges) { + if (root_node_idx >= start && root_node_idx <= end) { + is_part_of_event = true; + } + } + + if (is_part_of_event) { + continue; + } + + std::vector<std::vector<const ggml_tensor *>> nodes_per_branch; + for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) { + const ggml_tensor * node = cgraph->nodes[i]; + if (!is_noop(node) && depends_on(node, root_node)) { + nodes_per_branch.push_back({ node }); + } + } + + GGML_ASSERT(nodes_per_branch.size() == (size_t) count); + + //find the join point + const ggml_tensor * join_node = nullptr; + + const auto & belongs_to_branch = [&](const ggml_tensor * node, + const std::vector<const ggml_tensor *> & branch) -> bool { + for (const ggml_tensor * n : branch) { + if (depends_on(node, n)) { + return true; + } + } + return false; + }; + + for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) { + const ggml_tensor * curr_node = cgraph->nodes[i]; + + int num_joins = 0; + for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) { + if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) { + num_joins++; + } + } + + if (num_joins >= 2) { + join_node = curr_node; + break; + } + + bool found_branch = false; + for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) { + std::vector<const ggml_tensor *> & branch_vec = nodes_per_branch[branch_idx]; + if (belongs_to_branch(curr_node, branch_vec)) { + //continue accumulating + if (std::find(branch_vec.begin(), branch_vec.end(), curr_node) == branch_vec.end()) { + branch_vec.push_back(curr_node); + } + found_branch = true; + } + } + + if (!found_branch && is_noop(curr_node)) { + // we can put it in any branch because it will be ignored + nodes_per_branch[0].push_back({ curr_node }); + } + } + + if (join_node) { + //Create ggml_cuda_concurrent_event + ggml_cuda_concurrent_event concurrent_event(nodes_per_branch.size()); + concurrent_event.join_node = join_node; + + for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) { + for (const ggml_tensor * n : nodes_per_branch[branch_idx]) { + concurrent_event.stream_mapping[n] = branch_idx + 1; + } + } + + int fork_node_idx = node_indices[root_node]; + int join_node_idx = node_indices[join_node]; + + int current_branch_idx = 0; + int current_node_idx = fork_node_idx + 1; + const int n_branches = nodes_per_branch.size(); + + int total_branch_nodes = 0; + for (std::vector<const ggml_tensor *> branch_nodes : nodes_per_branch) { + total_branch_nodes += branch_nodes.size(); + } + + // there are other nodes in the middle which are unaccounted for + // usually (cpy) nodes, then ignore this fork + if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) { + GGML_LOG_DEBUG( + "Skipping %s because the number of nodes in the middle is not equal to the total number of " + "branch nodes %d != %d\n", + root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes); + continue; + } + + // Save the original order of nodes in this region before interleaving + // This is used later to restore grouping for fusion within streams + concurrent_event.original_order.reserve(total_branch_nodes); + for (int i = fork_node_idx + 1; i < join_node_idx; ++i) { + concurrent_event.original_order.push_back(cgraph->nodes[i]); + } + + std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events; + GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end()); + concurrent_events.emplace(root_node, std::move(concurrent_event)); + GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node); + concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx); + + // interleave tensors to extend lifetimes so that ggml graph doesn't recycle them + // example transformation: + // [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] -> + // [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn] + while (current_node_idx < join_node_idx) { + std::vector<const ggml_tensor *> & branch_nodes = nodes_per_branch[current_branch_idx]; + + bool has_node = false; + for (std::vector<const ggml_tensor *> branch_node : nodes_per_branch) { + has_node |= branch_node.size() > 0; + } + + GGML_ASSERT(has_node); + + if (branch_nodes.empty()) { + current_branch_idx = (current_branch_idx + 1) % n_branches; + continue; + } + + cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front()); + current_node_idx++; + branch_nodes.erase(branch_nodes.begin()); + + // append all empty nodes + while (!branch_nodes.empty() && is_noop(branch_nodes.front())) { + cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front()); + current_node_idx++; + branch_nodes.erase(branch_nodes.begin()); + } + + current_branch_idx = (current_branch_idx + 1) % n_branches; + } + } + } + } +} + +static const ggml_backend_i ggml_backend_cuda_interface = { + /* .get_name = */ ggml_backend_cuda_get_name, + /* .free = */ ggml_backend_cuda_free, + /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async, + /* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async, + /* .synchronize = */ ggml_backend_cuda_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_cuda_graph_compute, + /* .event_record = */ ggml_backend_cuda_event_record, + /* .event_wait = */ ggml_backend_cuda_event_wait, + /* .graph_optimize = */ ggml_backend_cuda_graph_optimize, +}; + +static ggml_guid_t ggml_backend_cuda_guid() { + static ggml_guid guid = { 0x2c, 0xdd, 0xe8, 0x1c, 0x65, 0xb3, 0x65, 0x73, 0x6a, 0x12, 0x88, 0x61, 0x1c, 0xc9, 0xdc, 0x25 }; + return &guid; +} + +bool ggml_backend_is_cuda(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cuda_guid()); +} + +int ggml_backend_cuda_get_device_count() { + return ggml_cuda_info().device_count; +} + +void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size) { + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); + snprintf(description, description_size, "%s", prop.name); +} + +void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total) { + ggml_cuda_set_device(device); + + CUDA_CHECK(cudaMemGetInfo(free, total)); +} + +bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) { + if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) { + return false; + } + +#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) || defined(GGML_USE_HIP) + cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); + if (err != cudaSuccess) { + // clear the error + (void)cudaGetLastError(); + + GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__, + size / 1024.0 / 1024.0, cudaGetErrorString(err)); + return false; + } + return true; +#else + GGML_UNUSED(buffer); + GGML_UNUSED(size); + return false; +#endif // CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) +} + +void ggml_backend_cuda_unregister_host_buffer(void * buffer) { + if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) { + return; + } + + cudaError_t err = cudaHostUnregister(buffer); + if (err != cudaSuccess) { + // clear the error + (void)cudaGetLastError(); + } +} + + +// backend device + +struct ggml_backend_cuda_device_context { + int device; + std::string name; + std::string description; + std::string pci_bus_id; + int op_offload_min_batch_size; +}; + +static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + return ctx->name.c_str(); +} + +static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + return ctx->description.c_str(); +} + +#if defined(__linux__) +// Helper function to get available memory from /proc/meminfo for UMA systems +static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_kb, long * free_swap_kb) { + FILE * meminfo_file = nullptr; + // 2KB buffer for reading /proc/meminfo since it does not report size info, should be enough + const size_t BUFFER_SIZE = 2048; + auto file_buffer = std::make_unique<char[]>(BUFFER_SIZE); + size_t bytes_read = 0; + long huge_tlb_total_pages = -1; + long huge_tlb_free_pages = -1; + long huge_tlb_page_size = -1; + + if (available_memory_kb == nullptr || free_swap_kb == nullptr) { + return false; + } + + meminfo_file = fopen("/proc/meminfo", "r"); + if (meminfo_file == nullptr) { + GGML_LOG_ERROR("%s: failed to open /proc/meminfo\n", __func__); + return false; + } + + // Read file into buffer + bytes_read = fread(file_buffer.get(), 1, BUFFER_SIZE - 1, meminfo_file); + fclose(meminfo_file); + + if (bytes_read == 0) { + GGML_LOG_ERROR("%s: failed to read from /proc/meminfo\n", __func__); + return false; + } + file_buffer[bytes_read] = '\0'; + + *available_memory_kb = -1; + *free_swap_kb = -1; + + // Parse the file buffer line by line + char * line = file_buffer.get(); + char * line_next; + while (line < file_buffer.get() + bytes_read) { + // Find the end of the current line + line_next = strchr(line, '\n'); + if (line_next != nullptr) { + *line_next = '\0'; + line_next++; + } else { + line_next = file_buffer.get() + bytes_read; + } + + long value; + if (sscanf(line, "MemAvailable: %ld kB", &value) == 1) { + *available_memory_kb = value; + } else if (sscanf(line, "SwapFree: %ld kB", &value) == 1) { + *free_swap_kb = value; + } else if (sscanf(line, "HugePages_Total: %ld", &value) == 1) { + huge_tlb_total_pages = value; + } else if (sscanf(line, "HugePages_Free: %ld", &value) == 1) { + huge_tlb_free_pages = value; + } else if (sscanf(line, "Hugepagesize: %ld kB", &value) == 1) { + huge_tlb_page_size = value; + } + + line = line_next; + } + + if (huge_tlb_total_pages != 0 && huge_tlb_total_pages != -1) { + *available_memory_kb = huge_tlb_free_pages * huge_tlb_page_size; + + // Hugetlbfs pages are not swappable. + *free_swap_kb = 0; + } + + GGML_LOG_DEBUG("%s: final available_memory_kb: %ld\n", __func__, *available_memory_kb); + return true; +} +#endif // defined(__linux__) + +static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemGetInfo(free, total)); + +// ref: https://github.com/ggml-org/llama.cpp/pull/17368 +#if defined(__linux__) + // Check if this is a UMA (Unified Memory Architecture) system + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device)); + + // Check if UMA is explicitly enabled via environment variable + bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr; + bool is_uma = prop.integrated > 0 || uma_env; + + if (is_uma) { + // For UMA systems (like DGX Spark), use system memory info + long available_memory_kb = 0; + long free_swap_kb = 0; + + if (ggml_backend_cuda_get_available_uma_memory(&available_memory_kb, &free_swap_kb) && available_memory_kb > 0) { + *free = (size_t)available_memory_kb * 1024; + } else { + GGML_LOG_ERROR("%s: /proc/meminfo reading failed, using cudaMemGetInfo\n", __func__); + } + } +#endif // defined(__linux__) + +} + +static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + + props->name = ggml_backend_cuda_device_get_name(dev); + props->description = ggml_backend_cuda_device_get_description(dev); + props->type = ggml_backend_cuda_device_get_type(dev); + props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); + ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); + + bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; +#ifdef GGML_CUDA_NO_PEER_COPY + bool events = false; +#else + bool events = true; +#endif + + props->caps = { + /* .async = */ true, + /* .host_buffer = */ host_buffer, + /* .buffer_from_host_ptr = */ false, + /* .events = */ events, + }; +} + +static ggml_backend_t ggml_backend_cuda_device_init_backend(ggml_backend_dev_t dev, const char * params) { + GGML_UNUSED(params); + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + return ggml_backend_cuda_init(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + return ggml_backend_cuda_buffer_type(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_host_buffer_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return ggml_backend_cuda_host_buffer_type(); +} + +// TODO: move these functions here +static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context; + + // split buffers can only be used with GGML_OP_MUL_MAT + if (op->op != GGML_OP_MUL_MAT) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda_split(op->src[i]->buffer->buft)) { + return false; + } + } + } + + // check if all the sources are allocated on this device + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda(op->src[i]->buffer->buft)) { + ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)op->src[i]->buffer->buft->context; + if (buft_ctx->device != dev_ctx->device) { + return false; + } + } + } + + switch (op->op) { + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_ABS: + case GGML_UNARY_OP_SGN: + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_EXPM1: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_XIELU: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_TRUNC: + return ggml_is_contiguous(op->src[0]); + default: + return false; + } + break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + return ggml_is_contiguous_1(op->src[0]); + default: + return false; + } + break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + { + struct ggml_tensor * a = op->src[0]; + struct ggml_tensor * b = op->src[1]; + if (a->buffer && ggml_backend_buft_is_cuda_split(a->buffer->buft)) { + if (a->ne[2] > 1 || a->ne[3] > 1) { + return false; + } + // for small weight matrices the active device can end up without any rows, don't use row split in those cases + // this avoids some edge cases (and the performance would not be good anyways) + ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) a->buffer->buft->context; + int64_t row_low; + int64_t row_high; + get_row_split(&row_low, &row_high, a, buft_ctx->tensor_split, dev_ctx->device); + if (row_low == row_high) { + return false; + } + } + if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) { + return false; + } +#ifdef GGML_USE_MUSA + const int cc = ggml_cuda_info().devices[dev_ctx->device].cc; + if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) { + if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT && + a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) { + return false; + } + if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID && + a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) { + return false; + } + } +#endif // GGML_USE_MUSA + switch (a->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_K: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_BF16: + return true; + default: + return false; + } + } break; + case GGML_OP_OUT_PROD: + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + case GGML_OP_GET_ROWS: + { + switch (op->src[0]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + case GGML_TYPE_BF16: + case GGML_TYPE_I32: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + return true; + default: + return false; + } + } break; + case GGML_OP_GET_ROWS_BACK: + { + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; + } break; + case GGML_OP_SET_ROWS: + { + return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 || + op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 || + op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) && + op->src[0]->type == GGML_TYPE_F32 && + (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32); + } break; + case GGML_OP_SET: + { + const ggml_type t = op->type; + return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) && + t == op->src[0]->type && + t == op->src[1]->type; + } break; + case GGML_OP_CPY: + { + ggml_type src0_type = op->src[0]->type; + ggml_type src1_type = op->src[1]->type; + if ((src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_F16) && + (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F16) + ) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) { + return true; + } + if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) { + return true; + } + if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) { + return true; + } + if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) { + return true; + } + if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) { + return true; + } + if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) { + return true; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) { + return true; + } + if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) { + return true; + } + if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_I32) { + return true; + } + if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) { + return true; + } + return false; + } break; + case GGML_OP_DUP: + { + ggml_type src0_type = op->src[0]->type; + return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; + } break; + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: + { + return true; + } break; + case GGML_OP_REPEAT: + { + ggml_type src0_type = op->src[0]->type; + return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; + } break; + case GGML_OP_REPEAT_BACK: + return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15); + case GGML_OP_CONCAT: + { + ggml_type src0_type = op->src[0]->type; + return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + ggml_type src0_type = op->src[0]->type; + ggml_type src1_type = op->src[1]->type; + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return true; + } + return false; + } break; + case GGML_OP_SILU_BACK: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: + return true; + case GGML_OP_RMS_NORM_BACK: + return ggml_is_contiguous(op->src[0]); + break; + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_ADD: + case GGML_OP_ADD_ID: + case GGML_OP_ADD1: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_LOG: + return true; + case GGML_OP_SSM_SCAN: { + if (op->src[3]->ne[0] == 1) { + // Mamba2 + // (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0) + return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0; + } else { + // Mamba + // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1) + return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1; + } + } + case GGML_OP_SSM_CONV: { + // assumes d_inner % threads == 0 + return op->src[0]->ne[1] % 128 == 0; + } + case GGML_OP_CONT: + return true; + case GGML_OP_DIAG_MASK_INF: + return true; + case GGML_OP_SOFT_MAX: + return true; + case GGML_OP_SOFT_MAX_BACK: { + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float)); + return max_bias == 0.0f; + } + case GGML_OP_ROLL: + if(op->src[0]->type == GGML_TYPE_F32) { + return true; + } + return false; + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: { + return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]); + } + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: + case GGML_OP_CONV_2D: + case GGML_OP_CONV_2D_DW: + case GGML_OP_CONV_TRANSPOSE_2D: + case GGML_OP_POOL_2D: + case GGML_OP_ACC: + return true; + case GGML_OP_SUM: + return ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_TOP_K: + case GGML_OP_ARGSORT: +#ifndef GGML_CUDA_USE_CUB + return op->src[0]->ne[0] <= 1024; +#else + return true; +#endif + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + case GGML_OP_GROUP_NORM: + return ggml_is_contiguous(op->src[0]); + case GGML_OP_PAD: + return true; + case GGML_OP_UPSCALE: + case GGML_OP_PAD_REFLECT_1D: + case GGML_OP_ARANGE: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_LEAKY_RELU: + case GGML_OP_RWKV_WKV6: + case GGML_OP_GATED_LINEAR_ATTN: + case GGML_OP_RWKV_WKV7: + return true; + case GGML_OP_FLASH_ATTN_EXT: + return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op); + case GGML_OP_CROSS_ENTROPY_LOSS: + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: + case GGML_OP_FILL: + case GGML_OP_CUMSUM: + case GGML_OP_TRI: + case GGML_OP_DIAG: + case GGML_OP_SOLVE_TRI: + return true; + + default: + return false; + } +} + +static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context; + const bool integrated = ggml_cuda_info().devices[dev_ctx->device].integrated; + return (((ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev) || (integrated && ggml_backend_buft_is_cuda_host(buft))); +} + +static int64_t get_op_batch_size(const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_GET_ROWS: + return 0; + case GGML_OP_MUL_MAT: + return op->ne[1]; + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + return op->ne[2]; + default: + return ggml_nrows(op); + } +} + +static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context; + + return get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size; +} + +static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) { +#ifdef GGML_CUDA_NO_PEER_COPY + return nullptr; +#else + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *)dev->context; + + ggml_cuda_set_device(dev_ctx->device); + + cudaEvent_t event; + CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + + return new ggml_backend_event { + /* .device = */ dev, + /* .context = */ event, + }; +#endif +} + +static void ggml_backend_cuda_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) { + GGML_UNUSED(dev); + + CUDA_CHECK(cudaEventDestroy((cudaEvent_t)event->context)); + delete event; +} + +static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { + GGML_UNUSED(dev); + CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context)); +} + +static const ggml_backend_device_i ggml_backend_cuda_device_interface = { + /* .get_name = */ ggml_backend_cuda_device_get_name, + /* .get_description = */ ggml_backend_cuda_device_get_description, + /* .get_memory = */ ggml_backend_cuda_device_get_memory, + /* .get_type = */ ggml_backend_cuda_device_get_type, + /* .get_props = */ ggml_backend_cuda_device_get_props, + /* .init_backend = */ ggml_backend_cuda_device_init_backend, + /* .get_buffer_type = */ ggml_backend_cuda_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_cuda_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ NULL, + /* .supports_op = */ ggml_backend_cuda_device_supports_op, + /* .supports_buft = */ ggml_backend_cuda_device_supports_buft, + /* .offload_op = */ ggml_backend_cuda_device_offload_op, + /* .event_new = */ ggml_backend_cuda_device_event_new, + /* .event_free = */ ggml_backend_cuda_device_event_free, + /* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize, +}; + +// backend reg + +struct ggml_backend_cuda_reg_context { + std::vector<ggml_backend_dev_t> devices; +}; + +static const char * ggml_backend_cuda_reg_get_name(ggml_backend_reg_t reg) { + GGML_UNUSED(reg); + return GGML_CUDA_NAME; +} + +static size_t ggml_backend_cuda_reg_get_device_count(ggml_backend_reg_t reg) { + ggml_backend_cuda_reg_context * ctx = (ggml_backend_cuda_reg_context *)reg->context; + return ctx->devices.size(); +} + +static ggml_backend_dev_t ggml_backend_cuda_reg_get_device(ggml_backend_reg_t reg, size_t index) { + ggml_backend_cuda_reg_context * ctx = (ggml_backend_cuda_reg_context *)reg->context; + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; +} + +static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t reg) { + static std::vector<ggml_backend_feature> features = []() { + std::vector<ggml_backend_feature> features; + #define _STRINGIFY(...) #__VA_ARGS__ + #define STRINGIFY(...) _STRINGIFY(__VA_ARGS__) + + #ifdef __CUDA_ARCH_LIST__ + features.push_back({ "ARCHS", STRINGIFY(__CUDA_ARCH_LIST__) }); + #endif + + #ifdef GGML_CUDA_FORCE_MMQ + features.push_back({ "FORCE_MMQ", "1" }); + #endif + + #ifdef GGML_CUDA_FORCE_CUBLAS + features.push_back({ "FORCE_CUBLAS", "1" }); + #endif + + #ifndef GGML_USE_VMM + features.push_back({ "NO_VMM", "1" }); + #endif + + #ifdef GGML_CUDA_NO_PEER_COPY + features.push_back({ "NO_PEER_COPY", "1" }); + #endif + + #ifdef GGML_CUDA_USE_GRAPHS + features.push_back({ "USE_GRAPHS", "1" }); + #endif + + #ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE + features.push_back({ "PEER_MAX_BATCH_SIZE", STRINGIFY(GGML_CUDA_PEER_MAX_BATCH_SIZE) }); + #endif + + #ifdef GGML_CUDA_FA_ALL_QUANTS + features.push_back({ "FA_ALL_QUANTS", "1" }); + #endif + + { + const auto & info = ggml_cuda_info(); + for (int id = 0; id < info.device_count; ++id) { + if (blackwell_mma_available(info.devices[id].cc)) { + features.push_back({ "BLACKWELL_NATIVE_FP4", "1"}); + break; + } + } + } + + #undef _STRINGIFY + #undef STRINGIFY + + features.push_back({ nullptr, nullptr }); + + return features; + }(); + + return features.data(); + + GGML_UNUSED(reg); +} + +static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) { + GGML_UNUSED(reg); + if (strcmp(name, "ggml_backend_split_buffer_type") == 0) { + return (void *)ggml_backend_cuda_split_buffer_type; + } + if (strcmp(name, "ggml_backend_register_host_buffer") == 0) { + return (void *)ggml_backend_cuda_register_host_buffer; + } + if (strcmp(name, "ggml_backend_unregister_host_buffer") == 0) { + return (void *)ggml_backend_cuda_unregister_host_buffer; + } + if (strcmp(name, "ggml_backend_get_features") == 0) { + return (void *)ggml_backend_cuda_get_features; + } + return nullptr; +} + +static const ggml_backend_reg_i ggml_backend_cuda_reg_interface = { + /* .get_name = */ ggml_backend_cuda_reg_get_name, + /* .get_device_count = */ ggml_backend_cuda_reg_get_device_count, + /* .get_device = */ ggml_backend_cuda_reg_get_device, + /* .get_proc_address = */ ggml_backend_cuda_reg_get_proc_address, +}; + +// backend registry +ggml_backend_reg_t ggml_backend_cuda_reg() { + static ggml_backend_reg reg; + static bool initialized = false; + + { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + if (!initialized) { + ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; + const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; + + for (int i = 0; i < ggml_cuda_info().device_count; i++) { + ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context; + dev_ctx->device = i; + dev_ctx->name = GGML_CUDA_NAME + std::to_string(i); + + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); + dev_ctx->description = prop.name; + + char pci_bus_id[16] = {}; + snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); + dev_ctx->pci_bus_id = pci_bus_id; + dev_ctx->op_offload_min_batch_size = min_batch_size; + + ggml_backend_dev_t dev = new ggml_backend_device { + /* .iface = */ ggml_backend_cuda_device_interface, + /* .reg = */ ®, + /* .context = */ dev_ctx + }; + ctx->devices.push_back(dev); + } + + reg = ggml_backend_reg { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_cuda_reg_interface, + /* .context = */ ctx + }; + } + + initialized = true; + } + + return ® +} + +ggml_backend_t ggml_backend_cuda_init(int device) { + if (device < 0 || device >= ggml_backend_cuda_get_device_count()) { + GGML_LOG_ERROR("%s: invalid device %d\n", __func__, device); + return nullptr; + } + + ggml_backend_cuda_context * ctx = new ggml_backend_cuda_context(device); + if (ctx == nullptr) { + GGML_LOG_ERROR("%s: failed to allocate context\n", __func__); + return nullptr; + } + + ggml_backend_t cuda_backend = new ggml_backend { + /* .guid = */ ggml_backend_cuda_guid(), + /* .iface = */ ggml_backend_cuda_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device), + /* .context = */ ctx, + }; + + return cuda_backend; +} + +GGML_BACKEND_DL_IMPL(ggml_backend_cuda_reg) diff --git a/llama.cpp/ggml/src/ggml-cuda/gla.cu b/llama.cpp/ggml/src/ggml-cuda/gla.cu new file mode 100644 index 0000000..f7d615a --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/gla.cu @@ -0,0 +1,93 @@ +#include "common.cuh" +#include "gla.cuh" + +template<int HEAD_SIZE> +static __global__ void gated_linear_attn_f32(const int B, const int T, const int C, const int H, const float scale, + const float * k, const float * v, const float * r, const float * td, const float * s, float * dst) { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + + const int head_size = HEAD_SIZE; + const int batch_i = bid / H; + const int head_i = bid % H; + const int state_size = C * head_size; + const int n_seq_tokens = T / B; + + float state[head_size]; + __shared__ float _k[head_size], _r[head_size], _td[head_size]; + + #pragma unroll + for (int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid]; + } + + for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) { + __syncthreads(); + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + __syncthreads(); + + const float _v = v[t]; + float y = 0; + for (int j = 0; j < head_size; j += 4) { + const float4 & k = (float4 &)(_k[j]); + const float4 & r = (float4 &)(_r[j]); + const float4 & td = (float4 &)(_td[j]); + float4 & s = (float4 &)(state[j]); + float4 kv; + + kv.x = k.x * _v; + kv.y = k.y * _v; + kv.z = k.z * _v; + kv.w = k.w * _v; + + s.x = s.x * td.x + kv.x; + s.y = s.y * td.y + kv.y; + s.z = s.z * td.z + kv.z; + s.w = s.w * td.w + kv.w; + + y += r.x * s.x; + y += r.y * s.y; + y += r.z * s.z; + y += r.w * s.w; + } + dst[t] = y * scale; + } + + #pragma unroll + for (int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i]; + } +} + +void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const float * k_d = (const float *)dst->src[0]->data; + const float * v_d = (const float *)dst->src[1]->data; + const float * r_d = (const float *)dst->src[2]->data; + const float * td_d = (const float *)dst->src[3]->data; + const float * s_d = (const float *)dst->src[4]->data; + + const int64_t B = dst->src[4]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + float scale; + memcpy(&scale, (float*)dst->op_params, sizeof(float)); + + float * dst_d = (float *)dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == 64 || C / H == 128); + + + if (C / H == 64) { + gated_linear_attn_f32<64><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d); + } else { + gated_linear_attn_f32<128><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/gla.cuh b/llama.cpp/ggml/src/ggml-cuda/gla.cuh new file mode 100644 index 0000000..2c82ad7 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/gla.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/im2col.cu b/llama.cpp/ggml/src/ggml-cuda/im2col.cu new file mode 100644 index 0000000..56dc054 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/im2col.cu @@ -0,0 +1,264 @@ +#include "im2col.cuh" + +#define MAX_GRIDDIM_Z 65535 + +template <typename T> +static __global__ void im2col_kernel( + const float * x, T * dst, + int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, + int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW, + int s0, int s1, int p0, int p1, int d0, int d1) { + const int64_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= IC_KH_KW) { + return; + } + + const int64_t iic = i / (KH_KW); + const int64_t rem = i - iic * KH_KW; + const int64_t ikh = rem / KW; + const int64_t ikw = rem - ikh * KW; + + const int64_t iow = blockIdx.y; + for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) { + const int64_t in = iz / OH; + const int64_t ioh = iz - in * OH; + + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; + + const int64_t offset_dst = + ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; + dst[offset_dst] = x[offset_src + iih * IW + iiw]; + } + } + + GGML_UNUSED(IC); + GGML_UNUSED(KH); +} + +// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] +template <typename T> +static void im2col_cuda(const float * x, T* dst, + int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC, + int64_t N, int64_t IC_IH_IW, int64_t IH_IW, + int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { + const int64_t IC_KH_KW = IC * KH * KW; + const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; + const int64_t N_OH = N * OH; + const int64_t KH_KW = KW*KH; + dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z)); + im2col_kernel<<<block_nums, MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH, + IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW, + s0, s1, p0, p1, d0, d1); +} + +static void im2col_cuda_f16(const float * x, half * dst, + int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC, + int64_t N, int64_t IC_IH_IW, int64_t IH_IW, + int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { + + im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); +} + +static void im2col_cuda_f32(const float * x, float * dst, + int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC, + int64_t N, int64_t IC_IH_IW, int64_t IH_IW, + int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) { + + im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); +} + +void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; + + const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; + + const int64_t IC = src1->ne[is_2D ? 2 : 1]; + const int64_t IH = is_2D ? src1->ne[1] : 1; + const int64_t IW = src1->ne[0]; + + const int64_t KH = is_2D ? src0->ne[1] : 1; + const int64_t KW = src0->ne[0]; + + const int64_t OH = is_2D ? dst->ne[2] : 1; + const int64_t OW = dst->ne[1]; + + const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const int64_t N = src1->ne[is_2D ? 3 : 2]; + const int64_t IH_IW = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 + + if(dst->type == GGML_TYPE_F16) { + im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); + } else { + im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); + } +} + +// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] +template <typename T> +static __global__ void im2col_3d_kernel( + const float * src, T * dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW, + int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW, + int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) { + const int64_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i >= IC_KD_KH_KW) { + return; + } + GGML_UNUSED(N); GGML_UNUSED(OC); GGML_UNUSED(OH_OW); GGML_UNUSED(OD); GGML_UNUSED(OW); GGML_UNUSED(KD); GGML_UNUSED(KH); + GGML_UNUSED(ID_IH_IW); GGML_UNUSED(IH_IW); GGML_UNUSED(IC_ID_IH_IW); GGML_UNUSED(OW_KD_KH_KW); + + const int64_t iic = i / KD_KH_KW; + const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW; + const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; + const int64_t ikw = i % KW; + + const int64_t iow = blockIdx.y; + for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) { + const int64_t in = iz / OD_OH; + const int64_t iod = (iz - in*OD_OH) / OH; + const int64_t ioh = iz % OH; + + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iid = iod * s2 + ikd * d2 - p2; + + const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); + dst[offset_dst] = src[offset_src]; + } + } +} + +// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] +template <typename T> +static void im2col_3d_cuda(const float * src, T* dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) { + const int64_t OH_OW = OH*OW; + const int64_t KD_KH_KW = KD*KH*KW; + const int64_t ID_IH_IW = ID*IH*IW; + const int64_t KH_KW = KH*KW; + const int64_t IH_IW = IH*IW; + const int64_t IC_KD_KH_KW = IC*KD*KH*KW; + const int64_t OW_KD_KH_KW = OW*KD*KH*KW; + const int64_t N_OD_OH = N*OD*OH; + const int64_t OD_OH = OD*OH; + const int64_t IC_ID_IH_IW = IC*ID*IH*IW; + const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW; + const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; + const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; + const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; + dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z)); + im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW, + IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW, + OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2); +} + +static void im2col_3d_cuda_f16(const float * src, half * dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) { + + im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); +} + +static void im2col_3d_cuda_f32(const float * src, float * dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) { + + im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); +} + +void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t OC = ne03 / IC; + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; + + const size_t es = ggml_element_size(src1); + const int64_t stride_x = src1->nb[0] / es; + const int64_t stride_y = src1->nb[1] / es; + const int64_t stride_z = src1->nb[2] / es; + const int64_t stride_q = src1->nb[3] / es; + + if(dst->type == GGML_TYPE_F16) { + im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); + } else { + im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/im2col.cuh b/llama.cpp/ggml/src/ggml-cuda/im2col.cuh new file mode 100644 index 0000000..2da1223 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/im2col.cuh @@ -0,0 +1,6 @@ +#include "common.cuh" + +#define CUDA_IM2COL_BLOCK_SIZE 256 + +void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/mean.cu b/llama.cpp/ggml/src/ggml-cuda/mean.cu new file mode 100644 index 0000000..49af538 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/mean.cu @@ -0,0 +1,75 @@ +#include "mean.cuh" +#include "reduce_rows.cuh" + +#ifdef GGML_CUDA_USE_CUB +#include <cub/cub.cuh> +using namespace cub; +#endif // GGML_CUDA_USE_CUB + +template <typename T> __global__ void divide_by_count(T * result, size_t count) { + *result /= static_cast<T>(count); +} + +void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + +// Special case for reducing vectors +#ifdef GGML_CUDA_USE_CUB +#ifdef USE_CUDA_GRAPH + cudaStreamCaptureStatus iscapturing; + CUDA_CHECK(cudaStreamIsCapturing(stream, &iscapturing)); +#endif // USE_CUDA_GRAPH + if ((nrows == 1) && +#ifdef USE_CUDA_GRAPH + // Determine if CUDA graphs are effectively disabled for this context + // (no graph instance exists and we're not capturing, OR graphs are explicitly enabled) + (((ncols > 65536) && + (((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) || + ctx.any_cuda_graph_enabled())) || + // CUDA graphs are enabled - use lower threshold + ((ncols > 32768) && + !(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) || + ctx.any_cuda_graph_enabled())))) { +#else + (ncols > 65536)) { +#endif // USE_CUDA_GRAPH + // Single row - use device-wide reduction + size_t tmp_size = 0; + ggml_cuda_pool & pool = ctx.pool(); + + DeviceReduce::Sum(nullptr, tmp_size, src0_d, dst_d, ncols, stream); + + ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size); + DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, src0_d, dst_d, ncols, stream); + + // Divide by ncols + divide_by_count<float><<<1, 1, 0, stream>>>(dst_d, ncols); + return; + } +#endif // GGML_CUDA_USE_CUB + + const dim3 block_nums(nrows, 1, 1); + + const int id = ggml_cuda_get_device(); + const int nsm = ggml_cuda_info().devices[id].nsm; + + // Heuristic for block size selection to optimize occupancy. + // See discussion in: https://github.com/ggml-org/llama.cpp/pull/15132 + if ((nrows / nsm) < 2) { + const dim3 block_dims(512, 1, 1); + reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols); + } else { + const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); + reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/mean.cuh b/llama.cpp/ggml/src/ggml-cuda/mean.cuh new file mode 100644 index 0000000..2b9b104 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/mean.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/mma.cuh b/llama.cpp/ggml/src/ggml-cuda/mma.cuh new file mode 100644 index 0000000..dd45d6c --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/mma.cuh @@ -0,0 +1,1381 @@ +#pragma once +// This file contains primitives that expose the tensor core PTX instructions for CUDA code. +// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout. +// The documentation for the PTX instructions can be found under: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction +// +// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C. +// A is a row-major matrix with shape M x K. +// B is a column-major matrix with shape K x N. +// C is a column-major matrix with shape M x N. +// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns. +// Note that J is measured in physical 32 bit elements instead of logical elements. +// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile. +// All matrix tiles have ne physical 32 bit elements per warp. +// +// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. +// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes, unaligned pointers are considered undefined behavior. + +#include "common.cuh" + +// On Volta each warp is doing 4 8x8 mma operations in parallel. +// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in I direction and to mirror the B tile. +// However, the i indices in this file are by default permuted to simplify the index calculations. +// #define GGML_CUDA_MMA_NO_VOLTA_PERM + +#if CUDART_VERSION >= 11080 + +static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { + int ret = 0; + +#ifdef TURING_MMA_AVAILABLE + asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" + : "=r"(ret) : "r"(x)); +#else + GGML_UNUSED(x); + NO_DEVICE_CODE; +#endif // defined(TURING_MMA_AVAILABLE) + return ret; +} + +#else + +static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { + // Imagine transposing row-major matrix to column-major matrix. + const int src_i_low = 2 * (threadIdx.x % 4); + const int src_i_high = src_i_low + 1; + const int src_j = threadIdx.x / 4; + + const int src_laneid_low = src_i_low * 4 + src_j / 2; + const int src_laneid_high = src_i_high * 4 + src_j / 2; + + const int shift_low = ((src_j + 0) % 2) * 16; + const int shift_high = ((src_j + 1) % 2) * 16; + + const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF; + const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000; + + return ret_low | ret_high; +} + +#endif // CUDART_VERSION >= 11080 + +static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) { + half2 ret; + *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x)); + return ret; +} + +namespace ggml_cuda_mma { + + // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel, + // effectively the warp is being split into subgroups of threads that each perform a single mma instruction. + // In those cases the data can be split in different ways across the warp. + enum data_layout { + // By default the data uses the I direction as its major dimension and the J direction as its minor dimension. + // For the A/C matrices this means I major == row major, J major == column major. + // For the B matrix this means I major == column major, J major == row major. + // MIRRORED == Each data value is held exactly once per thread subgroup. + DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA. + DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3. + DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3. + DATA_LAYOUT_J_MAJOR_MIRRORED = 30, + }; + // Implemented mma combinations are: + // - (I_MAJOR, I_MAJOR) -> I_MAJOR + // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR + // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR + + static constexpr bool is_i_major(const data_layout dl) { + return dl == DATA_LAYOUT_I_MAJOR || + dl == DATA_LAYOUT_I_MAJOR_MIRRORED; + } + + static constexpr __device__ data_layout get_input_data_layout() { +#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + return DATA_LAYOUT_I_MAJOR_MIRRORED; +#else + return DATA_LAYOUT_I_MAJOR; +#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + } + + template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR> + struct tile {}; + + template <int I_, int J_, typename T> + struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; + +#if defined(AMD_MFMA_AVAILABLE) + static constexpr int ne = I * J / 64; + T x[ne] = {0}; + + static constexpr __device__ bool supported() { + if (I == 64 && J == 2) return true; + if (I == 16 && J == 8) return true; + if (I == 32 && J == 4) return true; + if (I == 16 && J == 16) return true; + if (I == 32 && J == 32) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> + return threadIdx.x % 16; + } else if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 4) { + return threadIdx.x % 32; + } else if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 32) { + return threadIdx.x % 32; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> + return (2 * ((threadIdx.x / 16) % 2) + l); + } else if constexpr (I == 16 && J == 8) { + return 2 * (threadIdx.x / 16) + l; + } else if constexpr (I == 32 && J == 4) { + return 2 * (threadIdx.x / 32) + l; + } else if constexpr (I == 16 && J == 16) { + return 4 * (threadIdx.x / 16) + l; + } else if constexpr (I == 32 && J == 32) { + return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } +#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + static constexpr int ne = I * J / 32; + T x[ne] = {0}; + + static constexpr __device__ bool supported() { + if (I == 32 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 32 && J == 8) { +#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM + return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2); +#else + return (l & 2) + (threadIdx.x & ~2); +#endif // GGML_CUDA_MMA_NO_VOLTA_PERM + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 32 && J == 8) { + return (threadIdx.x & 2) + (l & (4 + 1)); + } else { + NO_DEVICE_CODE; + return -1; + } + } +#elif defined(AMD_WMMA_AVAILABLE) + static constexpr int ne = I * J / 32; + T x[ne] = {0}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 16) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (supported()) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 16) { +#if defined(RDNA3) + if constexpr (std::is_same_v<T, float> || std::is_same_v<T, int>) { + // matrix C + return 2 * l + (threadIdx.x / 16); + } else { + // matrix A&B + return l; + } +#else + // matrix C is the transposed matrix A&B on RDNA4 + return ne * (threadIdx.x / 16) + l; +#endif // defined(RDNA3) + } else if constexpr (I == 16 && J == 8) { + // mmq input for RDNA4 + return ne * (threadIdx.x / 16) + l; + } else if constexpr (I == 16 && J == 4) { + return ne * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } +#else + static constexpr int ne = I * J / 32; + T x[ne] = {0}; + + static constexpr __device__ bool supported() { + if (I == 8 && J == 4) return true; + if (I == 8 && J == 8) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 16) return true; + if (I == 32 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 4) { + return threadIdx.x / 4; + } else if constexpr (I == 8 && J == 8) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 8) { + return ((l / 2) * 8) + (threadIdx.x / 4); + } else if constexpr (I == 16 && J == 16) { + return (((l / 2) % 2) * 8) + (threadIdx.x / 4); + } else if constexpr (I == 32 && J == 8) { + return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction. + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return threadIdx.x % 4; + } else if constexpr (I == 8 && J == 8) { + return (l * 4) + (threadIdx.x % 4); + } else if constexpr (I == 16 && J == 8) { + return ((threadIdx.x % 4) * 2) + (l % 2); + } else if constexpr (I == 16 && J == 16) { + return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2); + } else if constexpr (I == 32 && J == 8) { + return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction. + } else { + NO_DEVICE_CODE; + return -1; + } + } +#endif // defined(GGML_USE_HIP) + }; + + template <int I_, int J_> + struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; + +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + static constexpr int ne = I * J / WARP_SIZE; + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 32 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 32 && J == 4) { +#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM + return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4); +#else + return threadIdx.x; +#endif // GGML_CUDA_MMA_NO_VOLTA_PERM + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 32 && J == 4) { + return l; + } else { + NO_DEVICE_CODE; + return -1; + } + } +#elif defined(AMD_WMMA_AVAILABLE) + static constexpr int ne = I * J / 32; + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 8) { + return ne * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } +#elif defined(AMD_MFMA_AVAILABLE) + static constexpr int ne = I * J / 64; + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 8) { + return ne * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } +#else + static constexpr int ne = I * J / WARP_SIZE; + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 8 && J == 4) return true; + if (I == 8 && J == 8) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 16) return true; + if (I == 32 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 8) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 4) { + return (l * 8) + (threadIdx.x / 4); + } else if constexpr (I == 16 && J == 8) { + return ((l % 2) * 8) + (threadIdx.x / 4); + } else if constexpr (I == 32 && J == 8) { + return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 8) { + return (l * 4) + (threadIdx.x % 4); + } else if constexpr (I == 16 && J == 4) { + return threadIdx.x % 4; + } else if constexpr (I == 16 && J == 8) { + return ((l / 2) * 4) + (threadIdx.x % 4); + } else if constexpr (I == 32 && J == 8) { + return ((l & 2) * 2) + (threadIdx.x % 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + }; + + template <int I_, int J_> + struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; + +#if defined(AMD_WMMA_AVAILABLE) + static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne; + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported(); + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l); + } + + static __device__ __forceinline__ int get_j(const int l) { + return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l); + } +#elif defined(AMD_MFMA_AVAILABLE) + static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne; + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported(); + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l); + } + + static __device__ __forceinline__ int get_j(const int l) { + return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l); + } +#else + static constexpr int ne = I * J / WARP_SIZE; + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 8 && J == 8) return true; + if (I == 16 && J == 4) return true; + if (I == 16 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 8) { + return threadIdx.x / 4; + } else if constexpr (I == 16 && J == 4) { + return (l * 8) + (threadIdx.x / 4); + } else if constexpr (I == 16 && J == 8) { + return ((l % 2) * 8) + (threadIdx.x / 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 8) { + return (l * 4) + (threadIdx.x % 4); + } else if constexpr (I == 16 && J == 4) { + return threadIdx.x % 4; + } else if constexpr (I == 16 && J == 8) { + return ((l / 2) * 4) + (threadIdx.x % 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } +#endif // defined(AMD_WMMA_AVAILABLE) + }; + + template <int I_, int J_, typename T> + struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR; + + static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne; + T x[ne] = {0}; + + static constexpr __device__ bool supported() { + return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported(); + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l); + } + + static __device__ __forceinline__ int get_j(const int l) { + return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l); + } + }; + + template <int I_, int J_, typename T> + struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_MIRRORED> { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; + + // RDNA3 + static constexpr int ne = I * J / 32 * 2; + + T x[ne] = {0}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 16) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int /*l*/) { + if constexpr (supported()) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (supported()) { + return l; + } else { + NO_DEVICE_CODE; + return -1; + } + } + }; + + template <int I_, int J_> + struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; +#if defined(RDNA3) + static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne; + + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported(); + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l); + } + + static __device__ __forceinline__ int get_j(const int l) { + return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l); + } +#else // Volta + static constexpr int ne = I * J / (WARP_SIZE/4); + + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 8 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int /*l*/) { + if constexpr (I == 8 && J == 4) { + return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return l; + } else { + NO_DEVICE_CODE; + return -1; + } + } +#endif // defined(RDNA3) + }; + + template <int I_, int J_> + struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED> { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED; + static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne; + + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported(); + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l); + } + + static __device__ __forceinline__ int get_j(const int l) { + return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l); + } + }; + + template <int I_, int J_> + struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED; + static constexpr int ne = I * J / (WARP_SIZE/4); + + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 8 && J == 4) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 8 && J == 4) { + return ((l / 2) * 4) + (threadIdx.x % 4); + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 8 && J == 4) { + return ((threadIdx.x / 16) * 2) + (l % 2); + } else { + NO_DEVICE_CODE; + return -1; + } + } + }; + +#if defined(TURING_MMA_AVAILABLE) + template <int I, int J> + static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) { + tile<I, J/2, half2> ret; +#pragma unroll + for (int l0 = 0; l0 < tile_float.ne; l0 += 2) { + ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); + } + return ret; + } + + static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) { + tile<8, 8, half2> ret; + ret.x[0] = ggml_cuda_movmatrix(t.x[0]); + ret.x[1] = ggml_cuda_movmatrix(t.x[1]); + + return ret; + } +#elif defined(AMD_WMMA_AVAILABLE) + template <int I, int J> + static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) { + tile<I, J/2, half2> ret; +#pragma unroll + for (int l0 = 0; l0 < tile_float.ne; l0 += 2) { + ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); + } + return ret; + } + + static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) { + NO_DEVICE_CODE; + return tile<8, 8, half2>{}; + } +#else // Volta + template <int I, int J> + static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) { + tile<I, J/2, half2> ret; +#pragma unroll + for (int l0 = 0; l0 < tile_float.ne; l0 += 4) { + ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); + ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]); + + // On Volta FP16 and FP32 tiles have a different memory layout, + // for the conversion threads with an offset of 2 need to exchange half their values: + ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync( + 0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE); + } + return ret; + } +#endif // defined(TURING_MMA_AVAILABLE) + + static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) { +#if defined(RDNA4) + const int row = t.get_i(0); + const int left_right = t.get_j(0) / 4; + const int up_down = row / 8; + const int idx = row % 8; + reinterpret_cast<half*>(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f; +#else + GGML_UNUSED_VARS(t); + NO_DEVICE_CODE; +#endif // defined(RDNA4) + } + + template <int I, int J, typename T, data_layout dl> + static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) { +#if defined(AMD_MFMA_AVAILABLE) + if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; + } + } else { + ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); + } +#elif defined(AMD_WMMA_AVAILABLE) + // All wmma layout has contiguous data when i-major. + if constexpr (is_i_major(dl)) { + // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes() + constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes(); + if constexpr (sizeof(t.x) > aligned_copy_bytes) { + static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size"); + constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes; +#pragma unroll + for (int i = 0; i < aligned_copy_count; ++i) { + ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i)); + } + } else { + ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); + } + } else { +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; + } + } +#else +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; + } +#endif // defined(AMD_MFMA_AVAILABLE) + } + + template <typename T> + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef TURING_MMA_AVAILABLE + int * xi = (int *) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J; + asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "=r"(xi[0]), "=r"(xi[1]) + : "l"(xs)); +#else + load_generic(t, xs0, stride); +#endif // TURING_MMA_AVAILABLE + } + + template <typename T> + static __device__ __forceinline__ void load_ldmatrix( + tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef TURING_MMA_AVAILABLE + int * xi = (int *) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride; + asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" + : "=r"(xi[0]), "=r"(xi[1]) + : "l"(xs)); +#else +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; +#else + load_generic(t, xs0, stride); +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // TURING_MMA_AVAILABLE + } + + template <typename T, data_layout dl> + static __device__ __forceinline__ void load_ldmatrix( + tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) { +#if defined(TURING_MMA_AVAILABLE) + int * xi = (int * ) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) + : "l"(xs)); +#else +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if 1 + // TODO: more generic handling + static_assert(sizeof(T) == 4, "bad type size"); + ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0); + ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4); +#else + load_generic(t, xs0, stride); +#endif // 1 +#else + load_generic(t, xs0, stride); +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // TURING_MMA_AVAILABLE + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) { + ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) { +#pragma unroll + for (int l0 = 0; l0 < t.ne; l0 += 2) { + ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0)); + } + } + + static __device__ __forceinline__ void load_ldmatrix( + tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); +#else + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + } + + template <typename T> + static __device__ __forceinline__ void load_ldmatrix_trans( + tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { +#ifdef TURING_MMA_AVAILABLE + int * xi = (int * ) t.x; + const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) + : "l"(xs)); +#else + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; +#endif // TURING_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) { +#ifdef TURING_MMA_AVAILABLE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead: + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[0]), "r"(B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[1]), "r"(B.x[0])); +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // TURING_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) { +#ifdef TURING_MMA_AVAILABLE +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1])); +#else + // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead: + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[0]), "r"(B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[1]), "r"(B.x[0])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[0]), "+r"(D.x[1]) + : "r"(A.x[2]), "r"(B.x[1])); + asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};" + : "+r"(D.x[2]), "+r"(D.x[3]) + : "r"(A.x[3]), "r"(B.x[1])); +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // TURING_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { +#ifdef TURING_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // TURING_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) { +#ifdef TURING_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" + : "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3])); +#else + // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[0]), "+r"(Dxi[1]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1])); + asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" + : "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#elif defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA4) + using halfx8_t = __attribute__((ext_vector_type(8))) _Float16; + halfx8_t& acc_frag = reinterpret_cast<halfx8_t&>(D.x[0]); + const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]); + const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(RDNA4) +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // TURING_MMA_AVAILABLE + } + + template <data_layout dl_ab, data_layout dl_d> + static __device__ __forceinline__ void mma( + tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) { +#ifdef AMPERE_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; + asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // AMPERE_MMA_AVAILABLE + } + + template <data_layout dl_ab, data_layout dl_d> + static __device__ __forceinline__ void mma( + tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) { +#ifdef AMD_MFMA_AVAILABLE + using floatx4_t = __attribute__((ext_vector_type(4))) float; + floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]); +#if defined(CDNA3) + using floatx2_t = __attribute__((ext_vector_type(2))) float; + const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]); + const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA1) +#pragma unroll + for (int i = 0; i < 2; ++i) { + acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0); + } +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(CDNA3) +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE + } + + static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D, + const tile<16, 8, int> & A, + const tile<8, 8, int> & B, + uint32_t a_scale, + uint32_t b_scale) { +#ifdef BLACKWELL_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + float * Dxi = (float *) D.x; + + asm volatile( + "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, " + "%10, {0, 0}, %11, {0, 0};" + : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale)); +#else + GGML_UNUSED_VARS(D, A, B, a_scale, b_scale); +#endif // BLACKWELL_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) { +#ifdef TURING_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // TURING_MMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) { +#ifdef AMPERE_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; + asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // AMPERE_MMA_AVAILABLE + } + + template <data_layout dl_ab, data_layout dl_d> + static __device__ __forceinline__ void mma( + tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) { +#ifdef TURING_MMA_AVAILABLE + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; +#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3])); +#else + // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead: + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1])); + asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" + : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#elif defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA4) + using halfx8_t = __attribute__((ext_vector_type(8))) _Float16; + using floatx8_t = __attribute__((ext_vector_type(8))) float; + floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]); + const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]); + const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag); +#elif defined(RDNA3) + using halfx16_t = __attribute__((ext_vector_type(16))) _Float16; + using floatx8_t = __attribute__((ext_vector_type(8))) float; + floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]); + const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]); + const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // RDNA4 +#elif defined(AMD_MFMA_AVAILABLE) + using halfx4_t = __attribute__((ext_vector_type(4))) _Float16; + using floatx4_t = __attribute__((ext_vector_type(4))) float; + floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]); + const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]); + const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // TURING_MMA_AVAILABLE + } + + template <data_layout dl_ab, data_layout dl_d> + static __device__ __forceinline__ void mma( + tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) { +#if defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA4) + using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16; + using floatx8_t = __attribute__((ext_vector_type(8))) float; + floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]); + const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]); + const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag); +#elif defined(RDNA3) + using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16; + using floatx8_t = __attribute__((ext_vector_type(8))) float; + floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]); + const bf16x16_t& a_frag = reinterpret_cast<const bf16x16_t&>(A.x[0]); + const bf16x16_t& b_frag = reinterpret_cast<const bf16x16_t&>(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(RDNA4) +#elif defined(AMD_MFMA_AVAILABLE) + using floatx4_t = __attribute__((ext_vector_type(4))) float; + floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]); +#if defined(CDNA3) || defined(CDNA2) + using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16; + const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]); + const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0); +#elif defined(CDNA1) +#pragma unroll + for (int i = 0; i < 2; ++i) { + using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16; + const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]); + const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0); + } +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(CDNA3) || defined(CDNA2) +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(AMD_WMMA_AVAILABLE) + } + + template <data_layout dl_d, data_layout dl_ab> + static __device__ __forceinline__ void mma( + tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) { +#if defined(AMD_MFMA_AVAILABLE) + using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + int32x4_t * acc = (int32x4_t *) D.x; +#if defined(CDNA3) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], + ((int64_t *) B.x)[0], + acc[0], + 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], + B.x[0], + acc[0], + 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], + B.x[1], + acc[0], + 0, 0, 0); +#endif // defined(CDNA3) + +#elif defined(AMD_WMMA_AVAILABLE) + + using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; + int32x8_t * acc = (int32x8_t *) D.x; + +#if defined(RDNA4) + using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; + int32x2_t * a_vec = (int32x2_t *) A.x; + int32x2_t * b_vec = (int32x2_t *) B.x; + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, + a_vec[0], + true, + b_vec[0], + acc[0], + true + ); + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, + a_vec[1], + true, + b_vec[1], + acc[0], + true + ); + +#elif defined(RDNA3) + using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + int32x4_t * a_vec = (int32x4_t *) A.x; + int32x4_t * b_vec = (int32x4_t *) B.x; + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( + true, + a_vec[0], + true, + b_vec[0], + acc[0], + true + ); + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( + true, + a_vec[1], + true, + b_vec[1], + acc[0], + true + ); +#endif // RDNA4 + +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) { +#if defined(AMD_MFMA_AVAILABLE) + using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; + int32x16_t * acc = (int32x16_t *) D.x; +#if defined(CDNA3) + acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], + ((int64_t *) B.x)[0], + acc[0], + 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA) + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], + B.x[0], + acc[0], + 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], + B.x[1], + acc[0], + 0, 0, 0); +#endif // defined(CDNA3) + +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE + } + + template <typename T1, typename T2, int J, int K> + static __device__ __forceinline__ void mma( + tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) { + tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D); + const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A); + mma(D16[0], A16[0], B); + mma(D16[1], A16[1], B); + } + + static __device__ __forceinline__ void mma( + tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; + asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1])); + asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3])); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + } + + static __device__ __forceinline__ void mma( + tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) { +#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + const int * Axi = (const int *) A.x; + const int * Bxi = (const int *) B.x; + int * Dxi = (int *) D.x; + asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 " + "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1])); + asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 " + "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};" + : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]) + : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3])); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + } + + template <data_layout dl_d, data_layout dl_ab> + static __device__ __forceinline__ void mma( + tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) { +#if defined(AMD_WMMA_AVAILABLE) + using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; + int32x8_t * acc = (int32x8_t *) D.x; +#if defined(RDNA4) + using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; + int32x2_t * a_vec = (int32x2_t *) A.x; + int32x2_t * b_vec = (int32x2_t *) B.x; + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, + a_vec[0], + true, + b_vec[0], + acc[0], + false + ); +#elif defined(RDNA3) + using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + int32x4_t * a_vec = (int32x4_t *) A.x; + int32x4_t * b_vec = (int32x4_t *) B.x; + + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( + true, + a_vec[0], + true, + b_vec[0], + acc[0], + false + ); +#endif // RDNA4 +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // AMD_WMMA_AVAILABLE + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/mmf.cu b/llama.cpp/ggml/src/ggml-cuda/mmf.cu new file mode 100644 index 0000000..aad4c34 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/mmf.cu @@ -0,0 +1,191 @@ +#include "ggml.h" +#include "mmf.cuh" +#include "mmid.cuh" + +static __forceinline__ int mmf_get_rows_per_block(const int cc) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return MMF_ROWS_PER_BLOCK_CDNA; + } else { + return MMF_ROWS_PER_BLOCK; + } +} + +void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { + GGML_ASSERT( src1->type == GGML_TYPE_F32); + GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + + GGML_TENSOR_BINARY_OP_LOCALS; + + const size_t ts_src0 = ggml_type_size(src0->type); + const size_t ts_src1 = ggml_type_size(src1->type); + const size_t ts_dst = ggml_type_size(dst->type); + + GGML_ASSERT(ne13 == ne3); + + GGML_ASSERT( nb00 == ts_src0); + GGML_ASSERT( nb10 == ts_src1); + GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); + GGML_ASSERT( nb0 == ts_dst); + + const float * src1_d = (const float *) src1->data; + const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; + float * dst_d = (float *) dst->data; + + const int64_t s01 = src0->nb[1] / ts_src0; + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s1 = dst->nb[1] / ts_dst; + const int64_t s02 = src0->nb[2] / ts_src0; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s2 = dst->nb[2] / ts_dst; + const int64_t s03 = src0->nb[3] / ts_src0; + const int64_t s13 = src1->nb[3] / ts_src1; + const int64_t s3 = dst->nb[3] / ts_dst; + + const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0; + const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + + mmf_ids_data ids_info{}; + mmf_ids_data * ids_info_ptr = nullptr; + ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev; + ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev; + ggml_cuda_pool_alloc<int32_t> expert_bounds_dev; + + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: + const int64_t ncols_dst = ids ? ne2 : ne1; + const int64_t nchannels_dst = ids ? ne1 : ne2; + + const int64_t stride_col_dst = ids ? s2 : s1; + const int64_t stride_col_y = ids ? s12 : s11; + const int64_t stride_channel_dst = ids ? s1 : s2; + + int64_t stride_channel_y = ids ? s11 : s12; + int64_t nchannels_y = ids ? ne11 : ne12; + + //mul_mat_id: handle broadcast + if (ids && nchannels_y == 1) { + stride_channel_y = 0; + nchannels_y = ids->ne[0]; + } + + if (ids && ncols_dst > 16) { + const int64_t n_expert_used = ids->ne[0]; + const int64_t n_experts = ne02; + const int64_t n_tokens = ne12; + const int64_t ne_get_rows = n_tokens * n_expert_used; + + ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows); + ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows); + expert_bounds_dev.alloc(ctx.pool(), n_experts + 1); + + const int si1 = static_cast<int>(ids_s1); + const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]); + + GGML_ASSERT(sis1 > 0); + + ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(), + static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream()); + CUDA_CHECK(cudaGetLastError()); + + ids_info.ids_src_compact = ids_src_compact_dev.get(); + ids_info.ids_dst_compact = ids_dst_compact_dev.get(); + ids_info.expert_bounds_dev = expert_bounds_dev.get(); + ids_info.n_experts = static_cast<int>(n_experts); + ids_info.sis1 = sis1; + ids_info_ptr = &ids_info; + } + + const int device = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[device].cc; + const int rows_per_block = mmf_get_rows_per_block(cc); + + switch (src0->type) { + case GGML_TYPE_F32: { + const float * src0_d = (const float *) src0->data; + constexpr int vals_per_T = 1; + mul_mat_f_switch_rows_per_block<float>( + rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); + } break; + case GGML_TYPE_F16: { + const half2 * src0_d = (const half2 *) src0->data; + constexpr int vals_per_T = 2; + mul_mat_f_switch_rows_per_block<half2>( + rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); + } break; + case GGML_TYPE_BF16: { + const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; + constexpr int vals_per_T = 2; + mul_mat_f_switch_rows_per_block<nv_bfloat162>( + rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); + } break; + default: + GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); + } +} + +bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, + const size_t * src0_nb, const int src1_ncols, bool mul_mat_id) { + if (ggml_is_quantized(type)) { + return false; + } + + const size_t ts = ggml_type_size(type); + if (src0_ne[0] % (warp_size * (4/ts)) != 0) { + return false; + } + + if (src0_nb[0] != ts) { + return false; + } + + // Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash: + for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { + if (src0_nb[i] % (2*ts) != 0) { + return false; + } + } + if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) { + return false; + } + + if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) { + return false; + } + + if (mul_mat_id) { + if (src0_ne[1] <= 1024 && src1_ncols > 512) { + return false; + } else if(src0_ne[1] > 1024 && src1_ncols > 128) { + return false; + } + } else { + if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) { + return false; + } else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) { + //TODO: truse CDNA2 as CDNA1, tune the perf when CDNA2 is available. + return false; + } else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) { + return false; + } else if (src1_ncols > 16) { + return false; + } + } + + switch (type) { + case GGML_TYPE_F32: + return ampere_mma_available(cc) || amd_mfma_available(cc); + case GGML_TYPE_F16: + return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc); + case GGML_TYPE_BF16: + return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc); + default: + return false; + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/mmf.cuh b/llama.cpp/ggml/src/ggml-cuda/mmf.cuh new file mode 100644 index 0000000..c2a8d54 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/mmf.cuh @@ -0,0 +1,908 @@ +#pragma once + +#include "mma.cuh" +#include "common.cuh" +#include "convert.cuh" + +using namespace ggml_cuda_mma; + +#define MMF_ROWS_PER_BLOCK 32 +#define MMF_ROWS_PER_BLOCK_CDNA 64 + +static __forceinline__ int64_t mmf_get_max_block_size(int cc) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return 512; + } else { + return 256; + } +} + +static __forceinline__ int mmf_get_padding(int cc) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return 2; + } else { + return 4; + } +} + +static constexpr __device__ int mmf_get_padding() { +#if defined(AMD_MFMA_AVAILABLE) + return 2; +#else + return 4; +#endif // defined(AMD_MFMA_AVAILABLE) +} + +struct mmf_ids_data { + const int32_t * ids_src_compact = nullptr; + const int32_t * ids_dst_compact = nullptr; + const int32_t * expert_bounds_dev = nullptr; + int n_experts = 0; + int sis1 = 0; +}; + +void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); + +bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const size_t * src0_nb, const int src1_ncols, bool mul_mat_id); + +template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids> +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f( + const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, + const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int stride_col_id, const int stride_row_id, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { +// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) + if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, get_input_data_layout()> tile_A; + typedef tile<16, 8, T, get_input_data_layout()> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; +#elif defined(AMD_MFMA_AVAILABLE) + if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; +#else +#ifdef VOLTA_MMA_AVAILABLE + if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { + typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; + typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; +#else + if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T> tile_A; + typedef tile<8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; +#endif // VOLTA_MMA_AVAILABLE +#endif // defined(AMD_WMMA_AVAILABLE) + if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) { + NO_DEVICE_CODE; + return; + } + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int tile_k_padded = warp_size + mmf_get_padding(); + constexpr int ntA = rows_per_block / tile_A::I; + constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; + + const int row0 = blockIdx.x * rows_per_block; + + int expert_idx = 0; + int col_base = 0; + + const int channel_dst = has_ids ? 0 : blockIdx.y; + + if constexpr (has_ids) { + // experts + tiles of ncols_dst are packed in the y dimension + int col_tiles = (ncols_dst_total + cols_per_block - 1) / cols_per_block; + const int nchannels_x = gridDim.y / col_tiles; + const int tile_idx = blockIdx.y / nchannels_x; + expert_idx = blockIdx.y - tile_idx * nchannels_x; + col_base = tile_idx * cols_per_block; + } + + const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio); + const int channel_y = channel_dst; + const int sample_dst = blockIdx.z; + const int sample_x = sample_dst / sample_ratio; + const int sample_y = sample_dst; + + x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ; + y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y); + dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst); + + if constexpr (has_ids) { + constexpr int y_stride_scale = std::is_same_v<T, float> ? 1 : 2; + const int64_t col_offset = col_base; + y += col_offset * stride_col_y * y_stride_scale; + dst += col_offset * stride_col_dst; + ids += col_offset * stride_row_id; + } + + const float2 * y2 = (const float2 *) y; + + extern __shared__ char data_mmv[]; + + char * shmem_base = data_mmv; + int * slot_map = (int *) shmem_base; + char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base; + + tile_C C[ntA][ntB]; + + T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded); + + if constexpr (has_ids) { + int found = 0; + + for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (threadIdx.x == 0) { + slot_map[j] = -1; + } + + if (col_base + j >= ncols_dst_total) { + continue; + } + + const int32_t * __restrict__ id_row = ids + j*stride_row_id; + + for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) { + int match = id_row[k*stride_col_id] == expert_idx; + + if (match) { + slot_map[j] = k; + found = 1; + break; + } + } + } + + if (!__syncthreads_or(found)) { + return; + } + } + + + for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { + tile_A A[ntA][warp_size / tile_A::J]; +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int i = 0; i < tile_A::I; ++i) { + tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { + load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); + } + } + +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { + if constexpr (std::is_same_v<T, float>) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + itB*tile_B::I; + + if constexpr (!has_ids) { + tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f; + } else { + const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0; + tile_xy[j0*tile_k_padded + threadIdx.x] = valid ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f; + } + } + } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + itB*tile_B::I; + + if constexpr (!has_ids) { + const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); + tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp); + } else { + const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0; + float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f); + tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp); + } + } + } else { + static_assert(std::is_same_v<T, void>, "unsupported type"); + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + } + } + + float * buf_iw = (float *) compute_base; + constexpr int kiw = nwarps*rows_per_block + mmf_get_padding(); + + if (nwarps > 1) { + __syncthreads(); + } +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); + const int j = itB*tile_C::J + tile_C::get_j(l); + buf_iw[j*kiw + i] = C[itA][itB].x[l]; + } + } + } + + if (nwarps > 1) { + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > cols_per_block && j >= cols_per_block) { + return; + } + + float sum[rows_per_block/warp_size] = {0.0f}; + static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size."); +#pragma unroll + for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { +#pragma unroll + for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) { + const int i = i0 + i1*warp_size + threadIdx.x; + + sum[i1] += buf_iw[j*kiw + i]; + } + } + + if constexpr (!has_ids) { +#pragma unroll + for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) { + dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0]; + } + } else { + const int slot = (j < cols_per_block) ? slot_map[j] : -1; + if (slot >= 0 && (col_base + j) < ncols_dst_total) { +#pragma unroll + for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) { + dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0]; + } + } + } + } + } +#else + GGML_UNUSED_VARS(x, y, ids, dst, + ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + NO_DEVICE_CODE; +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) +} + +//This kernel is for larger batch sizes of mul_mat_id +template <typename T, int rows_per_block, int cols_per_block, int nwarps> +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f_ids( + const T * __restrict__ x, const float * __restrict__ y, + const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact, + const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, + const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const uint3 sis1_fd, const uint3 nch_fd) { +// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) + if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, get_input_data_layout()> tile_A; + typedef tile<16, 8, T, get_input_data_layout()> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; +#elif defined(AMD_MFMA_AVAILABLE) + if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; +#else +#ifdef VOLTA_MMA_AVAILABLE + if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { + typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; + typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; +#else + if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T> tile_A; + typedef tile<8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; +#endif // VOLTA_MMA_AVAILABLE +#endif // defined(AMD_WMMA_AVAILABLE) + if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) { + NO_DEVICE_CODE; + return; + } + + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int tile_k_padded = warp_size + mmf_get_padding(); + constexpr int ntA = rows_per_block / tile_A::I; + constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; + + const int row0 = blockIdx.x * rows_per_block; + + const int expert_idx = blockIdx.y; + const int expert_start = expert_bounds[expert_idx]; + const int expert_end = expert_bounds[expert_idx + 1]; + const int ncols_expert = expert_end - expert_start; + + const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block; + const int tile_idx = blockIdx.z; + if (tile_idx >= tiles_for_expert) { + return; + } + + const int col_base = tile_idx * cols_per_block; + + GGML_UNUSED(channel_ratio); + + const int channel_x = expert_idx; + const int sample_dst = 0; + const int sample_x = sample_dst / sample_ratio; + const int sample_y = sample_dst; + + x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row; + y += int64_t(sample_y) *stride_sample_y; + dst += int64_t(sample_dst)*stride_sample_dst; + + const int32_t * ids_src_expert = ids_src_compact + expert_start; + const int32_t * ids_dst_expert = ids_dst_compact + expert_start; + + extern __shared__ char data_mmv[]; + char * compute_base = data_mmv; + + //const float2 * y2 = (const float2 *) y; + + tile_C C[ntA][ntB]; + + T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded); + + for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { + tile_A A[ntA][warp_size / tile_A::J]; +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int i = 0; i < tile_A::I; ++i) { + tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { + load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); + } + } + + if constexpr (std::is_same_v<T, float>) { + float vals_buf[2][tile_B::I]; + auto gather_tile = [&](int tile_idx_local, float *vals) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + tile_idx_local*tile_B::I; + const int global_j = col_base + j; + float val = 0.0f; + if (j < cols_per_block && global_j < ncols_expert) { + const int src_entry = ids_src_expert[global_j]; + const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd); + const int token = (int) qrm.x; + const int channel = (int) qrm.y; + if (token < ncols_dst_total) { + val = y[channel*stride_channel_y + token*stride_col_y + col]; + } + } + vals[j0] = val; + } + }; + + gather_tile(0, vals_buf[0]); + + int curr_buf = 0; + int next_buf = 1; +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0]; + } + + if (itB + 1 < ntB) { + gather_tile(itB + 1, vals_buf[next_buf]); + } + +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + + if (itB + 1 < ntB) { + curr_buf ^= 1; + next_buf ^= 1; + } + } + } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) { + float2 vals_buf[2][tile_B::I]; + auto gather_tile = [&](int tile_idx_local, float2 *vals) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + tile_idx_local*tile_B::I; + const int global_j = col_base + j; + float2 tmp = make_float2(0.0f, 0.0f); + if (j < cols_per_block && global_j < ncols_expert) { + const int src_entry = ids_src_expert[global_j]; + const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd); + const int token = (int) qrm.x; + const int channel = (int) qrm.y; + if (token < ncols_dst_total) { + tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)]; + } + } + vals[j0] = tmp; + } + }; + + if (ntB > 0) { + gather_tile(0, vals_buf[0]); + } + + int curr_buf = 0; + int next_buf = 1; +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const float2 tmp = vals_buf[curr_buf][j0]; + tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp); + } + + if (itB + 1 < ntB) { + gather_tile(itB + 1, vals_buf[next_buf]); + } + +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + + if (itB + 1 < ntB) { + curr_buf ^= 1; + next_buf ^= 1; + } + } + } else { + static_assert(std::is_same_v<T, void>, "unsupported type"); + } + } + + float * buf_iw = (float *) compute_base; + constexpr int kiw = nwarps*rows_per_block + mmf_get_padding(); + + if (nwarps > 1) { + __syncthreads(); + } +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); + const int j = itB*tile_C::J + tile_C::get_j(l); + buf_iw[j*kiw + i] = C[itA][itB].x[l]; + } + } + } + + if (nwarps > 1) { + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > cols_per_block && j >= cols_per_block) { + return; + } + + float sum[rows_per_block/warp_size] = {0.0f}; + static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size."); +#pragma unroll + for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { +#pragma unroll + for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) { + const int i = i0 + i1*warp_size + threadIdx.x; + + sum[i1] += buf_iw[j * kiw + i]; + } + } + + const int global_j = col_base + j; + if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) { + const int dst_entry = ids_dst_expert[global_j]; + const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd); + const int token = (int) qrm.x; + if (token < ncols_dst_total) { + const int slot = (int) qrm.y; +#pragma unroll + for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) { + dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0]; + } + } + } + } + } +#else + GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, + ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); + NO_DEVICE_CODE; +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) +} + +template<typename T, int rows_per_block, int cols_per_block, int nwarps> +static inline void mul_mat_f_switch_ids( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int64_t stride_row_id, + const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream, + const mmf_ids_data * ids_data) { + const bool has_ids_data = ids_data && ids_data->ids_src_compact; + + // Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16) + // we prefer the normal mul_mat_f path with has_ids=true. + if (has_ids_data && ncols_dst > 16) { + const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block); + if (max_tiles == 0) { + return; + } + dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles); + + const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1); + const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst); + + mul_mat_f_ids<T, rows_per_block, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>> + (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst, + ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, + sis1_fd, nch_fd); + } else if (ids) { + const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block; + dim3 block_nums_ids = block_nums; + block_nums_ids.y *= col_tiles; + + mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>> + (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } else { + mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>> + (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } +} + +template <typename T, int rows_per_block, int cols_per_block> +void mul_mat_f_cuda( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int64_t stride_row_id, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream, const mmf_ids_data * ids_data) { + typedef tile<16, 8, T> tile_A_16; + typedef tile<32, 8, T> tile_A_32; + typedef tile<16, 8, T> tile_B_16; + typedef tile< 8, 8, T> tile_B_8; + + GGML_ASSERT(ncols_x % 2 == 0); + GGML_ASSERT(stride_row % 2 == 0); + GGML_ASSERT(stride_col_y % 2 == 0); + GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); + GGML_ASSERT( nsamples_dst % nsamples_x == 0); + const int64_t channel_ratio = nchannels_dst / nchannels_x; + const int64_t sample_ratio = nsamples_dst / nsamples_x; + + const int device = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[device].cc; + const int warp_size = ggml_cuda_info().devices[device].warp_size; + + int64_t nwarps_best = 1; + int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2); + int64_t max_block_size = mmf_get_max_block_size(cc); + for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { + const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); + if (niter < niter_best) { + niter_best = niter; + nwarps_best = nwarps; + } + } + + const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4; + const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I; + const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4; + const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); + const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; + const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; + const int64_t grid_y = ids ? nchannels_x : nchannels_dst; + + const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst); + const dim3 block_dims(warp_size, nwarps_best, 1); + + switch (nwarps_best) { + case 1: { + mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 1>( + x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); + } break; + case 2: { + mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 2>( + x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); + } break; + case 3: { + mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 3>( + x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); + } break; + case 4: { + mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 4>( + x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); + } break; + case 5: { + mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 5>( + x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); + } break; + case 6: { + mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 6>( + x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); + } break; + case 7: { + mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 7>( + x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); + } break; + case 8: { + mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 8>( + x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, + ids_data); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } + + GGML_UNUSED_VARS(nchannels_y); +} + +template <typename T, int rows_per_block> +static void mul_mat_f_switch_cols_per_block( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int stride_row_id, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream, const mmf_ids_data * ids_data) { + + const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst; + + GGML_ASSERT(ids || ncols_dst <= 16); + + switch (ncols_case) { + case 1: { + mul_mat_f_cuda<T, rows_per_block, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case 2: { + mul_mat_f_cuda<T, rows_per_block, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case 3: { + mul_mat_f_cuda<T, rows_per_block, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case 4: { + mul_mat_f_cuda<T, rows_per_block, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case 5: { + mul_mat_f_cuda<T, rows_per_block, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case 6: { + mul_mat_f_cuda<T, rows_per_block, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case 7: { + mul_mat_f_cuda<T, rows_per_block, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case 8: { + mul_mat_f_cuda<T, rows_per_block, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case 9: { + mul_mat_f_cuda<T, rows_per_block, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case 10: { + mul_mat_f_cuda<T, rows_per_block, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case 11: { + mul_mat_f_cuda<T, rows_per_block, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case 12: { + mul_mat_f_cuda<T, rows_per_block, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case 13: { + mul_mat_f_cuda<T, rows_per_block, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case 14: { + mul_mat_f_cuda<T, rows_per_block, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case 15: { + mul_mat_f_cuda<T, rows_per_block, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case 16: { + mul_mat_f_cuda<T, rows_per_block, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +template <typename T> +static void mul_mat_f_switch_rows_per_block( + const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int stride_row_id, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream, const mmf_ids_data * ids_data) { + switch (rows_per_block) { + case MMF_ROWS_PER_BLOCK: { + mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK>( + x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case MMF_ROWS_PER_BLOCK_CDNA: { + mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK_CDNA>( + x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + default: + GGML_ABORT("unsupported rows_per_block: %i", rows_per_block); + } +} + +#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \ + template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \ + const T * x, const float * y, const int32_t * ids, float * dst, \ + const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \ + const int64_t stride_col_id, const int64_t stride_row_id, \ + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \ + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\ + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \ + cudaStream_t stream, const mmf_ids_data * ids_data); + +#if !defined(GGML_USE_MUSA) +#define DECL_MMF_CASE_EXTERN(ncols_dst) \ + extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) + +#define DECL_MMF_CASE(ncols_dst) \ + DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \ + DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \ + DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \ + DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \ + DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \ + DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) + +DECL_MMF_CASE_EXTERN(1); +DECL_MMF_CASE_EXTERN(2); +DECL_MMF_CASE_EXTERN(3); +DECL_MMF_CASE_EXTERN(4); +DECL_MMF_CASE_EXTERN(5); +DECL_MMF_CASE_EXTERN(6); +DECL_MMF_CASE_EXTERN(7); +DECL_MMF_CASE_EXTERN(8); +DECL_MMF_CASE_EXTERN(9); +DECL_MMF_CASE_EXTERN(10); +DECL_MMF_CASE_EXTERN(11); +DECL_MMF_CASE_EXTERN(12); +DECL_MMF_CASE_EXTERN(13); +DECL_MMF_CASE_EXTERN(14); +DECL_MMF_CASE_EXTERN(15); +DECL_MMF_CASE_EXTERN(16); +#else +#define DECL_MMF_CASE(ncols_dst) +#endif diff --git a/llama.cpp/ggml/src/ggml-cuda/mmid.cu b/llama.cpp/ggml/src/ggml-cuda/mmid.cu new file mode 100644 index 0000000..3c61e45 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/mmid.cu @@ -0,0 +1,164 @@ +#include "common.cuh" +#include "mmid.cuh" + +// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each. +struct mm_ids_helper_store { + uint32_t data; + + __device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) { + data = (it & 0x003FFFFF) | (iex_used << 22); + } + + __device__ uint32_t it() const { + return data & 0x003FFFFF; + } + + __device__ uint32_t iex_used() const { + return data >> 22; + } +}; +static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store"); + +// Helper function for mul_mat_id, converts ids to a more convenient format. +// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert. +// ids_dst describes the same mapping but for the dst tensor. +// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1]. +template <int n_expert_used_template> +__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mm_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template; + const int expert = blockIdx.x; + + extern __shared__ char data_mm_ids_helper[]; + mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper; + + int nex_prev = 0; // Number of columns for experts with a lower index. + int it_compact = 0; // Running index for the compact slice of this expert. + + if constexpr (n_expert_used_template == 0) { + // Generic implementation: + for (int it = 0; it < n_tokens; ++it) { + int iex_used = -1; // The index at which the expert is used, if any. + for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) { + const int expert_used = ids[it*si1 + iex]; + nex_prev += expert_used < expert; + if (expert_used == expert) { + iex_used = iex; + } + } + + if (iex_used != -1) { + store[it_compact] = mm_ids_helper_store(it, iex_used); + } + + if (warp_reduce_any<warp_size>(iex_used != -1)) { + it_compact++; + } + } + } else { + // Implementation optimized for specific numbers of experts used: + static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used"); + const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2. + for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) { + const int it = it0 + threadIdx.x / neu_padded; + + const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any. + const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ? + ids[it*si1 + iex] : INT_MAX; + const int iex_used = expert_used == expert ? iex : -1; + nex_prev += expert_used < expert; + + // Whether the threads at this token position have used the expert: + const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1); + + // Do a scan over threads at lower token positions in warp to get the correct index for writing data: + int it_compact_add_lower = 0; +#pragma unroll + for (int offset = neu_padded; offset < warp_size; offset += neu_padded) { + const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size); + if (threadIdx.x >= static_cast<unsigned int>(offset)) { + it_compact_add_lower += tmp; + } + } + + if (iex_used != -1) { + store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used); + } + + // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads: + it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size); + } + } + nex_prev = warp_reduce_sum<warp_size>(nex_prev); + + for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) { + const mm_ids_helper_store store_it = store[itc]; + const int it = store_it.it(); + const int iex_used = store_it.iex_used(); + ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y; + ids_dst [nex_prev + itc] = it*n_expert_used + iex_used; + } + + if (threadIdx.x != 0) { + return; + } + + expert_bounds[expert] = nex_prev; + + if (expert < static_cast<int>(gridDim.x) - 1) { + return; + } + + expert_bounds[gridDim.x] = nex_prev + it_compact; +} + +template <int n_expert_used_template> +static void launch_mm_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { + GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mm_ids_helper_store"); + GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store"); + + const int id = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[id].warp_size; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper<n_expert_used_template>, smpbo); + + const dim3 num_blocks(n_experts, 1, 1); + const dim3 block_size(warp_size, 1, 1); + const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store); + GGML_ASSERT(nbytes_shared <= smpbo); + mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>> + (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1); +} + +void ggml_cuda_launch_mm_ids_helper( + const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds, + const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) { + switch (n_expert_used) { + case 2: + launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 4: + launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 6: + launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 8: + launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 16: + launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + case 32: + launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + default: + launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); + break; + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/mmid.cuh b/llama.cpp/ggml/src/ggml-cuda/mmid.cuh new file mode 100644 index 0000000..ac090ae --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/mmid.cuh @@ -0,0 +1,5 @@ +#pragma once + +void ggml_cuda_launch_mm_ids_helper( + const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds, + int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream); diff --git a/llama.cpp/ggml/src/ggml-cuda/mmq.cu b/llama.cpp/ggml/src/ggml-cuda/mmq.cu new file mode 100644 index 0000000..9a69f41 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/mmq.cu @@ -0,0 +1,366 @@ +#include "common.cuh" +#include "mmq.cuh" +#include "quantize.cuh" +#include "mmid.cuh" + +static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { + switch (args.type_x) { + case GGML_TYPE_Q4_0: + mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream); + break; + case GGML_TYPE_Q4_1: + mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream); + break; + case GGML_TYPE_MXFP4: + mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream); + break; + case GGML_TYPE_Q2_K: + mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream); + break; + case GGML_TYPE_Q3_K: + mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream); + break; + case GGML_TYPE_Q4_K: + mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream); + break; + case GGML_TYPE_Q5_K: + mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream); + break; + case GGML_TYPE_Q6_K: + mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream); + break; + case GGML_TYPE_IQ2_XXS: + mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_q_case<GGML_TYPE_IQ2_XS>(ctx, args, stream); + break; + case GGML_TYPE_IQ2_S: + mul_mat_q_case<GGML_TYPE_IQ2_S>(ctx, args, stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_q_case<GGML_TYPE_IQ3_XXS>(ctx, args, stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_q_case<GGML_TYPE_IQ3_S>(ctx, args, stream); + break; + case GGML_TYPE_IQ1_S: + mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream); + break; + case GGML_TYPE_IQ4_XS: + mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream); + break; + case GGML_TYPE_IQ4_NL: + mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + +void ggml_cuda_mul_mat_q( + ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { + GGML_ASSERT( src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID. + + GGML_TENSOR_BINARY_OP_LOCALS; + + cudaStream_t stream = ctx.stream(); + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + const size_t ts_src0 = ggml_type_size(src0->type); + const size_t ts_src1 = ggml_type_size(src1->type); + const size_t ts_dst = ggml_type_size(dst->type); + + GGML_ASSERT( nb00 == ts_src0); + GGML_ASSERT( nb10 == ts_src1); + GGML_ASSERT( nb0 == ts_dst); + GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); + + const char * src0_d = (const char *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + // If src0 is a temporary compute buffer, clear any potential padding. + if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { + const size_t size_data = ggml_nbytes(src0); + const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0); + if (size_alloc > size_data) { + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); + GGML_ASSERT(!src0->view_src); + CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream)); + } + } + + const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING); + + const int64_t s01 = src0->nb[1] / ts_src0; + const int64_t s1 = dst->nb[1] / ts_dst; + const int64_t s02 = src0->nb[2] / ts_src0; + const int64_t s2 = dst->nb[2] / ts_dst; + const int64_t s03 = src0->nb[3] / ts_src0; + const int64_t s3 = dst->nb[3] / ts_dst; + + const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) + || GGML_CUDA_CC_IS_CDNA(cc); + + // TODO: tighter pool buffer size vs q8 path + const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4; + + if (!ids) { + const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq); + ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1); + + { + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s13 = src1->nb[3] / ts_src1; + if (use_native_mxfp4) { + static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1)); + quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, + ne11, ne12, ne13, stream); + + } else { + quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, + ne11, ne12, ne13, stream); + } + CUDA_CHECK(cudaGetLastError()); + } + + // Stride depends on quantization format + const int64_t s12 = use_native_mxfp4 ? + ne11 * ne10_padded * sizeof(block_fp4_mmq) / + (8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32) + : + ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); + const int64_t s13 = ne12*s12; + + const mmq_args args = { + src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d, + ne00, ne01, ne1, s01, ne11, s1, + ne02, ne12, s02, s12, s2, + ne03, ne13, s03, s13, s3, + use_stream_k, ne1}; + ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); + return; + } + + GGML_ASSERT(ne13 == 1); + GGML_ASSERT(nb12 % nb11 == 0); + GGML_ASSERT(nb2 % nb1 == 0); + + const int64_t n_expert_used = ids->ne[0]; + const int64_t ne_get_rows = ne12 * n_expert_used; + GGML_ASSERT(ne1 == n_expert_used); + + ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows); + ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows); + ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1); + + { + GGML_ASSERT(ids->nb[0] == ggml_element_size(ids)); + const int si1 = ids->nb[1] / ggml_element_size(ids); + const int sis1 = nb12 / nb11; + + ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(), + ne02, ne12, n_expert_used, ne11, si1, sis1, stream); + CUDA_CHECK(cudaGetLastError()); + } + + const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 + + get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq); + ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1); + + const int64_t ne11_flat = ne12*n_expert_used; + const int64_t ne12_flat = 1; + const int64_t ne13_flat = 1; + + { + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s13 = src1->nb[3] / ts_src1; + + if (use_native_mxfp4) { + quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, + ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); + } else { + quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, + ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); + } + CUDA_CHECK(cudaGetLastError()); + } + + const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) : + ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); + const int64_t s13 = ne12*s12; + + // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid. + const mmq_args args = { + src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d, + ne00, ne01, ne_get_rows, s01, ne_get_rows, s1, + ne02, ne02, s02, s12, s2, + ne03, ne13, s03, s13, s3, + use_stream_k, ne12}; + + ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); +} + +void ggml_cuda_op_mul_mat_q( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + GGML_ASSERT(ne10 % QK8_1 == 0); + + const int64_t ne0 = dst->ne[0]; + + const int64_t row_diff = row_high - row_low; + const int64_t stride01 = ne00 / ggml_blck_size(src0->type); + + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + + // the main device has a larger memory buffer to hold the results from all GPUs + // nrows_dst == nrows of the matrix that the kernel writes into + const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; + + // The stream-k decomposition is only faster for recent NVIDIA GPUs. + // Also its fixup needs to allocate a temporary buffer in the memory pool. + // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer. + const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) + || GGML_CUDA_CC_IS_CDNA(cc)) + && src1_ncols == ne11; + const mmq_args args = { + src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i, + ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst, + 1, 1, 0, 0, 0, + 1, 1, 0, 0, 0, + use_stream_k, src1_ncols}; + + ggml_cuda_mul_mat_q_switch_type(ctx, args, stream); + + GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_padded_row_size); +} + +bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts) { +#ifdef GGML_CUDA_FORCE_CUBLAS + return false; +#endif // GGML_CUDA_FORCE_CUBLAS + + bool mmq_supported; + + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + mmq_supported = true; + break; + default: + mmq_supported = false; + break; + } + + if (!mmq_supported) { + return false; + } + + if (turing_mma_available(cc)) { + return true; + } + + if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) { + return false; + } + +#ifdef GGML_CUDA_FORCE_MMQ + return true; +#endif //GGML_CUDA_FORCE_MMQ + + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + } + + if (amd_mfma_available(cc)) { + // As of ROCM 7.0 rocblas/tensile performs very poorly on CDNA3 and hipblaslt (via ROCBLAS_USE_HIPBLASLT) + // performs better but is currently suffering from a crash on this architecture. + // TODO: Revisit when hipblaslt is fixed on CDNA3 + if (GGML_CUDA_CC_IS_CDNA3(cc)) { + return true; + } + if (n_experts > 64 || ne11 <= 128) { + return true; + } + if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) { + return true; + } + if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) { + return true; + } + return false; + } + + if (amd_wmma_available(cc)) { + if (GGML_CUDA_CC_IS_RDNA3(cc)) { + // High expert counts are almost always better on MMQ due to + // the synchronization overhead in the cuBLAS/hipBLAS path: + // https://github.com/ggml-org/llama.cpp/pull/18202 + if (n_experts >= 64) { + return true; + } + + // For some quantization types MMQ can have lower peak TOPS than hipBLAS + // so it's only faster for sufficiently small batch sizes: + switch (type) { + case GGML_TYPE_Q2_K: + return ne11 <= 128; + case GGML_TYPE_Q6_K: + return ne11 <= (GGML_CUDA_CC_IS_RDNA3_0(cc) ? 128 : 256); + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + return GGML_CUDA_CC_IS_RDNA3_5(cc) || ne11 <= 128; + default: + return true; + } + } + + // For RDNA4 MMQ is consistently faster than dequantization + hipBLAS: + // https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301 + return true; + } + + return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; + +} diff --git a/llama.cpp/ggml/src/ggml-cuda/mmq.cuh b/llama.cpp/ggml/src/ggml-cuda/mmq.cuh new file mode 100644 index 0000000..f80f98c --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/mmq.cuh @@ -0,0 +1,4092 @@ +#pragma once + +#include "common.cuh" +#include "vecdotq.cuh" +#include "mma.cuh" + +#include <climits> +#include <cstdint> + +using namespace ggml_cuda_mma; + +#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. +#define MMQ_ITER_K 256 +#define MMQ_ITER_K_MXFP4_FP4 512 +#define MMQ_NWARPS 8 + +typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride); +typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00); +typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted, + float * __restrict__ dst, const int stride, const int i_max, const int j_max); + +enum mmq_q8_1_ds_layout { + MMQ_Q8_1_DS_LAYOUT_D4, + MMQ_Q8_1_DS_LAYOUT_DS4, + MMQ_Q8_1_DS_LAYOUT_D2S6, +}; + +struct block_q8_1_mmq { + // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block. + // The y float data is first grouped as blocks of 128 values. + // These blocks are then treated as individual data values and transposed. + // + // To avoid shared memory bank conflicts each block is padded with 16 bytes. + // This padding is also used to store block scales/partial sums. + // The scales multiplied with the quantized data are equal to the unquantized values. + // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization) + // and are only needed for performance reasons. + // + // The exact data stored depends on the x data type. + union { + float d4[4]; // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3 + half2 ds4[4]; // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3 + half d2s6[8]; // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values, + // stored as d0,d1,s1,s2,s3,s4,s5 + }; + int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each +}; + +struct block_fp4_mmq { + uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc. + int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values +}; + +static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); +static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size"); +static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size"); + +static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { + switch (type_x) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q5_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q5_1: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q8_0: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_MXFP4: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q2_K: + return MMQ_Q8_1_DS_LAYOUT_D2S6; + case GGML_TYPE_Q3_K: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_IQ1_S: + return MMQ_Q8_1_DS_LAYOUT_DS4; + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + return MMQ_Q8_1_DS_LAYOUT_D4; + default: + GGML_ABORT("fatal error"); + break; + } +} + +struct tile_x_sizes { + int qs; + int dm; + int sc; +}; + +static int get_mmq_x_max_host(const int cc) { + return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : + GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? +#ifdef GGML_CUDA_FORCE_MMQ + 128 : 64; +#else + MMQ_DP4A_MAX_BATCH_SIZE : 64; +#endif // GGML_CUDA_FORCE_MMQ +} + +static constexpr __device__ int get_mmq_x_max_device() { +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + return 128; +#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + +#if defined(GGML_USE_HIP) + return 64; +#else // defined(GGML_USE_HIP) + +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#ifdef GGML_CUDA_FORCE_MMQ + return 128; +#else // GGML_CUDA_FORCE_MMQ + return MMQ_DP4A_MAX_BATCH_SIZE; +#endif // GGML_CUDA_FORCE_MMQ +#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + return 64; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + +#endif // defined(GGML_USE_HIP) +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +} + +static int get_mmq_y_host(const int cc) { + return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : + ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64); +} + +static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) { +#if defined(BLACKWELL_MMA_AVAILABLE) + return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K; +#else + return MMQ_ITER_K; +#endif // defined(BLACKWELL_MMA_AVAILABLE) +} + +static constexpr __device__ int get_mmq_y_device() { +#if defined(GGML_USE_HIP) +#if defined(RDNA1) + return 64; +#else + return 128; +#endif // defined RDNA1 +#else +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + return 128; +#else + return 64; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(GGML_USE_HIP) +} + +// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes. +// The K dimension of the tiles has either, +// 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K), +// 32 bit elements for the quantized data (does not include scales). +// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K. +// The final tile size in K direction is padded to avoid shared memory bank conflicts, +// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma. +#define MMQ_TILE_NE_K 32 + +#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0} +#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0} +#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0} +#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0} +#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0} +#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0} +#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} + +static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { + switch (type) { + case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0; + case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1; + case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K; + case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K; + case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K; + case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K; + case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K; + case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16; + case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0; + case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0; + default: return tile_x_sizes{0, 0, 0}; + } +} + +#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) +#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) +#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) +#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) + +static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding."); +static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4"); + +static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; + // tile sizes are the same for Q8_1 and FP4 for blackwell + case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; + case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K; + case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K; + case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0; + case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0; + default: return 0; + } +} + +// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales) +#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1) +#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K + +static int mmq_get_granularity_host(const int mmq_x, const int cc) { + if (amd_mfma_available(cc) || amd_wmma_available(cc)) { + return mmq_x >= 128 ? 32 : 16; + } else if (turing_mma_available(cc) && mmq_x >= 48) { + return 16; + } else { + return 8; + } +} + +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { + return mmq_x >= 128 ? 32 : 16; +} +#elif defined(TURING_MMA_AVAILABLE) +static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { + return mmq_x >= 48 ? 16 : 8; +} +#else +static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) { + return 8; +} +#endif // AMD_MFMA_AVAILABLE + +#if defined(GGML_USE_HIP) +static int mmq_get_nwarps_host(const int cc, const int warp_size) { + return amd_mfma_available(cc) ? 8 : 256/warp_size; +} +#else +static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) { + return 256/warp_size; +} +#endif // (GGML_USE_HIP) + +static constexpr __device__ int mmq_get_nwarps_device() { +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + return 8; +#else + return 256/ggml_cuda_get_physical_warp_size(); +#endif // AMD_MFMA_AVAILABLE +} + +// ------------------------------------------------------------ + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_0; + const int kqsx = txi % QI4_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; + const int qs0 = get_int_b2(bxi->qs, kqsx); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); +#else + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + +template <int mmq_x, int mmq_y> +static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); + + int u[2*VDR_Q4_0_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; + } + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ> + (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, + x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_1; + const int kqsx = txi % QI4_1; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; + const int qs0 = get_int_b4(bxi->qs, kqsx); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; +#else + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; +#else + x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + +template <int mmq_x, int mmq_y> +static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); + + int u[2*VDR_Q4_1_Q8_1_MMQ]; + +#pragma unroll + for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { + u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; + u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; + } + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ> + (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, + x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI5_0; + const int kqsx = txi % QI5_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx; + + const int ql = get_int_b2(bxi->qs, kqsx); + const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI5_1; + const int kqsx = txi % QI5_1; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx; + + const int ql = get_int_b4(bxi->qs, kqsx); + const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx); + + int qs0 = (ql >> 0) & 0x0F0F0F0F; + qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 + qs0 |= (qh << 11) & 0x00001000; // 1 -> 12 + qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 + qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 + + int qs1 = (ql >> 4) & 0x0F0F0F0F; + qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 + qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12 + qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 + qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; +#else + x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + // MMQ_ITER_K / (4 * QR8_0) == 64 required. but NV has only 32 threads per warp + constexpr int threads_per_row = 32; + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI8_0; + const int kqsx = txi % QI8_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#else + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI_MXFP4; + const int kqsx = txi % QI_MXFP4; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx; + + const int aux_q4 = get_int_b1(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); + const int k0 = kbx * (2 * QI_MXFP4) + kqsx; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; +#else + x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + +template <int mmq_y, bool need_check> +static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x, + int * __restrict__ x_tile, + const int kbx0, + const int i_max, + const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + int * x_qs = (int *) x_tile; + uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K); + + const int txi = threadIdx.x; + + constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4); + + constexpr int threads_per_row = iter_k / QK_MXFP4; // each thread processes 1 block + constexpr int rows_per_warp = warp_size / threads_per_row; + const int kbx = txi % threads_per_row; + const int row_in_warp = txi / threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) { + int i = i0 + threadIdx.y * rows_per_warp + row_in_warp; + + if constexpr (need_check) { + i = min(i, i_max); + } + + const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx; + + // quantize_mxfp4_mmq permutes nibbles to match the quantized format + const int k0 = kbx * 4; + memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16); + + // Load E8M0 scales: pack 2 consecutive scales into one uint32 + if (kbx % 2 == 0) { + uint32_t e = bxi->e; + e |= ((bxi + 1)->e << 8); + x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e; + } + } +} + +template <int mmq_x, int mmq_y> +static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ> + (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K], + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]); + } + } + } +} + +template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout> +static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 8, int, input_layout> tile_A; + typedef tile<16, 8, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + float dB; + const int j = j0 + tile_C::get_j(0); + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { + dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } else { + dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(l); + const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB; + } + } + } + } +#else + typedef tile<16, 8, int> tile_A; + typedef tile< 8, 8, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; + + tile_A A[ntx][MMQ_TILE_NE_K/QI8_0]; + float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0]; + + const int i0 = (threadIdx.y/ntx)*rows_per_warp; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + + load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + } + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + + dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + tile_B B; + float dB[tile_C::ne/2]; + + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } else { + dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n][k01/QI8_0], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2]; + } + } + } + } +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +} + +template <int mmq_x, int mmq_y> +static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x, + const int * __restrict__ y, + float * __restrict__ sum, + const int k00) { + typedef tile<16, 8, int> tile_A; + typedef tile<8, 8, int> tile_B; + typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K); + + // Match layout from load_tiles_mxfp4_fp4 + const int * x_qs = (const int *) x; + const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K); + const int * y_qs = (const int *) y + 4; + const uint32_t * y_sc = (const uint32_t *) y; + + // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4 + tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)]; + uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)]; + + // Block scale + // Each thread has to point to a 4 byte scale value + // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) { + const int k0 = k00 + k01; + + load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0, + MMQ_MMA_TILE_X_K_FP4); + + // based on block-scaling document, 2 threads in each quad need to supply to the scale value + const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8; + scaleA[n][k01 / (2 * QI_MXFP4)] = + *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4)); + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) { + tile_B B; + uint32_t scaleB; // 2xN scales + + load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K); + + scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + + mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l]; + } + } + } + } +} + +template <int mmq_x, int mmq_y> +static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ> + (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template <int mmq_x, int mmq_y> +static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 8, int, input_layout> tile_A; + typedef tile<16, 8, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const half2 * y_dm = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(l); + float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y; + } + } + } + } +#else + typedef tile<16, 8, int> tile_A; + typedef tile< 8, 8, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const half2 * y_dm = (const half2 *) y; + + tile_A A[ntx][MMQ_TILE_NE_K/QI8_1]; + float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1]; + + const int i0 = (threadIdx.y/ntx)*rows_per_warp; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + } + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + tile_B B; + float2 dsB[tile_C::ne/2]; + + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n][k01/QI8_1], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y; + } + } + } + } +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +} + +// Used for Q3_K, IQ2_S, and IQ2_XS +template <int mmq_x, int mmq_y> +static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl<QI8_0>( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], + &y_qs[j*MMQ_TILE_Y_K + k01], + &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)], + y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +// Used for Q3_K, IQ2_S, and IQ2_XS: +template <int mmq_x, int mmq_y> +static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 8, int, input_layout> tile_A; + typedef tile<16, 8, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; + typedef tile<64, 2, int, input_layout> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 4, int, input_layout> tile_A; + typedef tile<16, 4, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; + } + } + } + } +#elif defined(TURING_MMA_AVAILABLE) + + typedef tile<16, 4, int> tile_A; + typedef tile<16, 8, int> tile_A_8; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); + + tile_A A[ntx][8]; + float dA[ntx][tile_C::ne/2][8]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { + const int k0 = k00 + k01; + + load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4]; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + tile_B B[2]; + float dB[tile_C::ne/2]; + + // Here load_generic is faster than load_ldmatrix. + load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C[2]; + mma(C[0], A[n][k01/4 + 0], B[0]); + mma(C[1], A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]); + } + } + } + } +#else + GGML_UNUSED_VARS(x, y, sum, k00); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K); + constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride; + + const int x_ql_0 = get_int_b2(bxi->qs, kqsx); + +#pragma unroll + for (int l = 0; l < QR2_K; ++l) { + const int k = (kqsx/8)*32 + l*8 + kqsx % 8; + + const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + + const int sc_m = bxi->scales[kqsx]; +#ifdef FAST_FP16_AVAILABLE + const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4)); +#else + const float2 bxi_dmf = __half22float2(bxi->dm); + const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4)); +#endif // FAST_FP16_AVAILABLE + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; +#else + x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + +template <int mmq_x, int mmq_y> +static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + float2 y_df[mmq_x/nwarps]; +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]); + } + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + constexpr int ns = 2; + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + } + } + + // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop. + // As a workaround 2 separate loops are used instead. +#pragma unroll + for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + constexpr int ns = 1; + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + } + } +} + +template <int mmq_x, int mmq_y> +static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 8, int, input_layout> tile_A; + typedef tile<16, 8, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; + typedef tile<64, 2, int, input_layout> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2; + const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 + : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y + : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); + + tile_C Cm; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tile_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + mma(Cm, A1, B[0]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C Cd; + mma(Cd, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); + float tmp = Cd.x[l]*dm.x; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tmp -= Cm.x[l]*dm.y; + } + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 4, int, input_layout> tile_A; + typedef tile<16, 4, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y; + const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 + : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y + : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); + + tile_C Cm; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tile_A A1; +#pragma unroll + for (int l = 0; l < tile_A::ne; ++l) { + A1.x[l] = 0x01010101; + } + mma(Cm, A1, B); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C Cd; + mma(Cd, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); + float tmp = Cd.x[l]*dm.x; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tmp -= Cm.x[l]*dm.y; + } + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; + } + } + } + } +#elif defined(TURING_MMA_AVAILABLE) + + typedef tile<16, 4, int> tile_A; + typedef tile<16, 8, int> tile_A_8; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); + + tile_A A[ntx][8]; + float dA[ntx][tile_C::ne/2][8]; + float mA[ntx][tile_C::ne/2][8]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) { + const int k0 = k00 + k01; + + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]); + + dA[n][l][k01/(QI8_1/2)] = dm.x; + mA[n][l][k01/(QI8_1/2)] = dm.y; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + float2 dB[tile_C::ne/2]; + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]); + } + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + tile_B B[2]; + + // Here load_generic is faster than load_ldmatrix. + load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); + + tile_C Cm[2]; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tile_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + mma(Cm[0], A1, B[0]); + mma(Cm[1], A1, B[1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C Cd[2]; + + mma(Cd[0], A[n][k01/4 + 0], B[0]); + mma(Cd[1], A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1]; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1]; + } + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? dB[l%2].x : dB[l%2].y); + } + } + } + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) { + float2 sB[tile_C::ne/2]; + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y; + } + } + } + } +#else + GGML_UNUSED_VARS(x, y, sum, k00); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); + int * x_sc = (int *) (x_df + txs.dm); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + + const int x_ql_0 = get_int_b2(bxi->qs, kqsx); + const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2))); + +#pragma unroll + for (int l = 0; l < QR3_K; ++l) { + const int k = (kqsx/8)*32 + l*8 + kqsx % 8; + + const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303; + const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404; + + const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + } + + constexpr int rows_per_warp = warp_size / 4; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + + const int ksc = threadIdx.x % 4; + + const int ksc_low = ksc % (QI3_K/8); + const int shift_low = 4 * (ksc / (QI3_K/8)); + const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F; + + const int ksc_high = QI3_K/8; + const int shift_high = 2 * ksc; + const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030; + + const int sc = __vsubss4(sc_low | sc_high, 0x20202020); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + const int8_t * sc8 = (const int8_t *) ≻ + const float d = bxi->d; + +#pragma unroll + for (int l = 0; l < int(sizeof(int)); ++l) { + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l]; + } +#else + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + +#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)) +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + + x_df[i] = bxi->d; + } +#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) || defined(AMD_WMMA_AVAILABLE) +} + +template <int mmq_x, int mmq_y> +static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * x_sc = (const int *) x_df + txs.dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4; + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales, + x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) { + // scale arrangement after the following two lines: + // - ksc == 0: sc0, sc1, sc2, sc3 + // - ksc == 1: sc4, sc5, sc6, sc7 + // - ksc == 2: m0, m1, m2, m3 + // - ksc == 3: m4, m5, m6, m7 + return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits + ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); + int * x_sc = (int *) (x_dm + txs.dm); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const int qs0 = get_int_b4(bxi->qs, txi); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; +#else + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + constexpr int rows_per_warp = warp_size / 2; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + // Need if on AMD instead of % because warp_size == 64 + // This causes double work and throughput loss (MI300X) + // H100 loses about 100 t/s with 'if' condition over '%' + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; + if (i < mmq_y) { +#else + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; + { +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % 2; + + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); + + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; + + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); + + #pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } + } + } +#else +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + + x_dm[i] = bxi->dm; + } + constexpr int rows_per_warp = warp_size / 4; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8); + + const int * scales = (const int *) bxi->scales; + + const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8); + const int scales8 = unpack_scales_q45_K(scales, ksc); + + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; + } +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +} + +template <int mmq_x, int mmq_y> +static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * x_sc = (const int *) x_dm + txs.dm; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16); + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq( + &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_dm = (half2 *) (x_qs + txs.qs); + int * x_sc = (int *) (x_dm + txs.dm); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const int ky = QR5_K*txi; + + const int ql = get_int_b4(bxi->qs, txi); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010; + + const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0; + const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + constexpr int rows_per_warp = warp_size / 2; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { +#if defined(AMD_MFMA_AVAILABLE) + // Need if on AMD instead of % because warp_size == 64 + // This causes double work and throughput loss (MI300X) + // H100 loses about 100 t/s with 'if' condition over '%' + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; + if (i < mmq_y) { +#else + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; + { +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % 2; + + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); + + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; + + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); + +#pragma unroll + for (int l = 0; l < int(sizeof(int)); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } + } + } +#else +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + + x_dm[i] = bxi->dm; + } + + constexpr int rows_per_warp = warp_size / 4; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + + const int * scales = (const int *) bxi->scales; + + const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8); + const int scales8 = unpack_scales_q45_K(scales, ksc); + + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; + } +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +} + +template <int mmq_x, int mmq_y> +static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + txs.qs; + const int * x_sc = (const int *) x_dm + txs.dm; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16); + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq( + &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); + int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); + int * x_sc = (int *) (x_df + txs.dm); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; + + const int ql = get_int_b2(bxi->ql, txi); + const int ql0 = (ql >> 0) & 0x0F0F0F0F; + const int ql1 = (ql >> 4) & 0x0F0F0F0F; + + const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4)); + const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030; + const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030; + + const int kq0 = 2*txi - txi % (QI6_K/2) + 0; + const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d; +#else + x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + + constexpr int rows_per_warp = warp_size / 4; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); +#else + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8)); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + +template <int mmq_x, int mmq_y> +static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * x_sc = (const int *) x_df + txs.dm; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]); + + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq( + &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, + x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template <int mmq_x, int mmq_y> +static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 8, int, input_layout> tile_A; + typedef tile<16, 8, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; + typedef tile<64, 2, int, input_layout> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; + } + } + } + } +#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles + constexpr data_layout input_layout = get_input_data_layout(); + typedef tile<16, 4, int, input_layout> tile_A; + typedef tile<16, 4, int, input_layout> tile_B; + typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; + } + } + } + } +#elif defined(TURING_MMA_AVAILABLE) + + typedef tile<16, 4, int> tile_A; + typedef tile< 8, 4, int> tile_B; + typedef tile<16, 8, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I); + + tile_A A[ntx][8]; + int scA[ntx][tile_C::ne/2][8]; + float dA[ntx][tile_C::ne/2]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { + const int k0 = k00 + k01; + + load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); + load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) { + const int k0 = k00 + k01; + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); + + const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16]; + const int8_t * sc = (const int8_t *) &sc_packed; + +#pragma unroll + for (int ksc = 0; ksc < sizeof(int); ++ksc) { + scA[n][l][k01/4 + ksc] = sc[ksc]; + } + } + } + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); + + dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K]; + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + float tmp[ntx][tile_C::ne] = {{0.0f}}; + +#pragma unroll + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { + tile_B B[2]; + float dB[tile_C::ne/2]; + + // Here load_generic is faster than load_ldmatrix. + load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K); + load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K); + +#pragma unroll + for (int l = 0; l < tile_C::ne/2; ++l) { + const int j = j0 + tile_C::get_j(l); + + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C[2]; + mma(C[0], A[n][k01/4 + 0], B[0]); + mma(C[1], A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2]; + } + } + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2]; + } + } + } +#else + GGML_UNUSED_VARS(x, y, sum, k00); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_NL; + const int kqsx = txi % QI4_NL; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx; + + const int aux_q4 = get_int_b2(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); + const int k0 = kbx * (2 * QI4_NL) + kqsx; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; + const int kbxd = threadIdx.x % blocks_per_tile_x_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); +#else + x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride; + + const int q2 = get_int_b2(bxi->qs, 2*kqsx+0); + const uint8_t * aux8 = (const uint8_t *) &q2; + const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1); + +#pragma unroll + for (int l = 0; l < QR2_XXS; ++l) { + const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]); + const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F]; + + const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); + const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + + const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); + const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + + const int ls = aux32 >> 28; + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; +#else + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride; + + const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint16_t * q2 = (const uint16_t *) &q2_packed; + + #pragma unroll + for (int l = 0; l < QR2_XS; ++l) { + const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF)); + const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); + + const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + + const int ls = bxi->scales[kqsx]; + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#else + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride; + + const int qs_packed = get_int_b2(bxi->qs, kqsx); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + const int signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + +#pragma unroll + for (int l = 0; l < QR2_S; ++l) { + const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300))); + + const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + + const int ls = bxi->scales[kqsx]; + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#else + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride; + + const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint8_t * q3 = (const uint8_t *) &q3_packed; + const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx); + +#pragma unroll + for (int l = 0; l < QR3_XXS; ++l) { + const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]); + + const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F)); + + const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + + const int ls = aux32 >> 28; + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; +#else + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride; + + const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + const int signs_packed_32 = get_int_b2(bxi->signs, kqsx); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + +#pragma unroll + for (int l = 0; l < QR3_S; ++l) { + const int2 grid_pos = make_int2( + iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)], + iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]); + + const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + + const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); + const float d = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; +#else + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_ds = (half2 *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride; + + const int qs_packed = get_int_b2(bxi->qs, kqsx); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + #pragma unroll + for (int l = 0; l < QR1_S/2; ++l) { + const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + + const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); + const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); +#else + x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; + + const int aux_q4 = get_int_b4(bxi->qs, kqsx); + const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); + const int k0 = 8 * (kqsx / 4) + kqsx % 4; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + + constexpr int rows_per_warp = warp_size / 8; +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; + + const float d = __half2float(bxi->d); + + const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) + | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); +#else + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + +template<int mmq_x, int mmq_y, bool need_check> +static __device__ __forceinline__ void mmq_write_back_dp4a( + const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst, + const int stride, const int i_max, const int j_max) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j > j_max) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + if (need_check && i > i_max) { + continue; + } + + dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; + } + } +} + +template<ggml_type type, int mmq_x, int mmq_y, bool need_check> +static __device__ __forceinline__ void mmq_write_back_mma( + const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst, + const int stride, const int i_max, const int j_max) { + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int nwarps = mmq_get_nwarps_device(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + constexpr int tileC_IJ = mmq_get_granularity_device(0); + typedef tile<tileC_IJ, tileC_IJ, int, DATA_LAYOUT_J_MAJOR> tile_C; + constexpr int rows_per_warp = granularity; +#else + typedef tile<16, 8, int> tile_C; + constexpr int rows_per_warp = 2 * granularity; +#endif // defined(AMD_MFMA_AVAILABLE) + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); +#else + GGML_UNUSED(nwarps); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l); + + if (j > j_max) { + continue; + } + + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + + if (need_check && i > i_max) { + continue; + } + + dst[ids_dst[j]*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l]; + } + } + } +} + +// ------------------------------------------------------------------------------------------------------------------------------------- + +template <int mmq_x, int mmq_y, bool need_check, ggml_type type> +struct mmq_type_traits; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> { + static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_DS4>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> { + static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> { + static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> { + static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> { + static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> { + static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ; +#ifdef BLACKWELL_MMA_AVAILABLE + static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>; +#else + static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>; +#endif // BLACKWELL_MMA_AVAILABLE + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> { + static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> { + static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> { + static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> { + static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> { + static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> { + static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> { + static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> { + static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> { + static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> { + static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> { + static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> { + static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> { + static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>; +}; + +template <ggml_type type, int mmq_x, bool need_check, bool fixup> +static __device__ __forceinline__ void mul_mat_q_process_tile( + const char * __restrict__ x, const int offset_x, const int * __restrict__ y, + const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup, + const int stride_row_x, const int ncols_y, const int stride_col_dst, + const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) { + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int qk = ggml_cuda_type_traits<type>::qk; + constexpr int mmq_y = get_mmq_y_device(); + constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, need_check, type>::load_tiles; + + extern __shared__ int data_mul_mat_q[]; + int * tile_y = data_mul_mat_q + mmq_x; + int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma; + constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>; +#else + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a; + constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + +#if defined(BLACKWELL_MMA_AVAILABLE) + // FP4 tile stores 8 blocks + constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1; +#else + constexpr int ne_block = 4 * QK8_1; +#endif // defined(BLACKWELL_MMA_AVAILABLE) + + constexpr int ITER_K = get_iter_k(type); + constexpr int blocks_per_iter = ITER_K / qk; + + float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; + + constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int); + + for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { + load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x); + { + const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz; +#pragma unroll + for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) { + int l = l0 + threadIdx.y*warp_size + threadIdx.x; + + tile_y[l] = by0[l]; + } + } + + __syncthreads(); + + vec_dot(tile_x, tile_y, sum, 0); + + __syncthreads(); + + { + const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz); +#pragma unroll + for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) { + int l = l0 + threadIdx.y*warp_size + threadIdx.x; + + tile_y[l] = by0[l]; + } + } + + __syncthreads(); + + vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K); + + __syncthreads(); + } + + if (fixup) { + write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x); + } else { + write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j); + } +} + + +// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598 + +template <ggml_type type, int mmq_x, bool need_check> +#if defined(GGML_USE_HIP) +#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) +#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) +#else +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1) +#else + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(GGML_USE_HIP) +static __global__ void mul_mat_q( + const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst, + const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup, + const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, + const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const int ncols_max) { + + // Skip unused template specializations for faster compilation: + if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) { + NO_DEVICE_CODE; + return; + } + + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr int qk = ggml_cuda_type_traits<type>::qk; + constexpr int mmq_y = get_mmq_y_device(); + + const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x + const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y + + // Initialize the ids for writing back data with just the index. + // For regular matrix multiplications this is never changed. + // For MoE the correct indices are loaded from ids_dst. + extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory. +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; + + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = j; + } + __syncthreads(); + + // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: +#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA + { + const int wt = blockIdx.z / nchannels_y; + const int zt = blockIdx.z - wt*nchannels_y; + const int jt = blockIdx.y; + const int it = blockIdx.x; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_dst; + int col_diff = ncols_dst; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + return; + } + + // __syncthreads(); // There is no previous tile that could cause a race condition. +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; + + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; + } + __syncthreads(); + } + + offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + + constexpr bool fixup = false; + mul_mat_q_process_tile<type, mmq_x, need_check, fixup> + (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); + return; + } +#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA + + constexpr int ITER_K = get_iter_k(type); + + const int64_t blocks_per_ne00 = ncols_x / qk; + constexpr int blocks_per_iter = ITER_K / qk; + + // kbc == k block continuous, current index in continuous ijk space. + int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + + kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; + kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter; + + // kb0 == k index when doing the matrix multiplication for an output tile. + int kb0_start = kbc % blocks_per_ne00; + int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc); + while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) { + int tmp = kbc; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_dst; + int col_diff = ncols_dst; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + kbc += blocks_per_ne00; + kbc -= kbc % blocks_per_ne00; + + kb0_start = 0; + kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + + continue; + } + + __syncthreads(); +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; + + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; + } + __syncthreads(); + } + + offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + + constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. + mul_mat_q_process_tile<type, mmq_x, need_check, fixup> + (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); + + kbc += blocks_per_ne00; + kbc -= kbc % blocks_per_ne00; + + kb0_start = 0; + kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + } + + if (kbc >= kbc_stop) { + return; + } + + int tmp = kbc; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; + + // Defaults for regular matrix multiplication: + int col_low = 0; + int col_high = ncols_dst; + int col_diff = ncols_dst; + int offset_y = wt*stride_sample_y + zt*stride_channel_y; + int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst; + + if (ids_dst) { + col_low = expert_bounds[zt + 0]; + col_high = expert_bounds[zt + 1]; + col_diff = col_high - col_low; + + offset_y = 0; + offset_dst = 0; + + if (jt*mmq_x >= col_diff) { + return; + } + + // The memory layout for the fixup buffer is always contiguous, therefore reset ids: + __syncthreads(); +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; + + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { + break; + } + + ids_dst_shared[j] = j; + } + __syncthreads(); + } + + offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int)); + offset_dst += it*mmq_y; + + const int tile_x_max_i = nrows_x - it*mmq_y - 1; + const int tile_y_max_j = col_diff - jt*mmq_x - 1; + + const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + + constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. + mul_mat_q_process_tile<type, mmq_x, need_check, fixup> + (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, + tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); +} + +template <ggml_type type, int mmq_x, bool need_check> +static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, + const int32_t * expert_bounds, + float * __restrict__ dst, + const float * __restrict__ tmp_last_tile, + const int ncols_x, + const int nrows_x, + const int ncols_dst, + const size_t stride_col_dst, + const int nchannels_y, + const size_t stride_channel_dst, + const int nsamples_y, + const size_t stride_sample_dst, + const int ncols_max) { + constexpr int mmq_y = get_mmq_y_device(); + constexpr int qk = ggml_cuda_type_traits<type>::qk; + constexpr int ITER_K = get_iter_k(type); + + constexpr int blocks_per_iter = ITER_K / qk; + const int64_t blocks_per_ne00 = ncols_x / qk; + + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; + + const int ntx = (ncols_max + mmq_x - 1) / mmq_x; + const int nty = (nrows_x + mmq_y - 1) / mmq_y; + + const int bidx0 = blockIdx.x; + + // kbc == k block continuous, current index in continuous ijk space. + int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + + kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter; + kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter; + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0; + const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } + + bool any_fixup = false; + + // Iterate over previous blocks and sum up partial sums written to fixup buffer. + // All CUDA blocks that get here must have a previous block that needs a fixup. + int64_t bidx = bidx0 - 1; + int64_t kbc_stop = kbc0; + while(true) { + int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; + + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; + continue; + } + + any_fixup = true; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; + } + } + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) { + break; + } + bidx--; + kbc_stop = kbc; + } + + if (!any_fixup) { + return; + } + + int tmp = kbc0; + const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); + const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); + tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); + const int zt = tmp / (ntx*blocks_per_ne00); + tmp -= zt * (ntx*blocks_per_ne00); + const int jt = tmp / blocks_per_ne00; + + if (!ids_dst) { + const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y; + dst += offset_dst; + + const int i_max = nrows_x - it*mmq_y - 1; + const int j_max = ncols_dst - jt*mmq_x - 1; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j > j_max) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + if (need_check && i > i_max) { + continue; + } + + dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; + } + } + return; + } + + __shared__ int ids_dst_shared[mmq_x]; + const int col_low = expert_bounds[zt + 0]; + const int col_high = expert_bounds[zt + 1]; + const int col_diff = col_high - col_low; + + for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) { + ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j]; + } + __syncthreads(); + + const int offset_dst = it*mmq_y; + dst += offset_dst; + + const int i_max = nrows_x - it*mmq_y - 1; + const int j_max = col_diff - jt*mmq_x - 1; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j > j_max) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + if (need_check && i > i_max) { + continue; + } + + dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; + } + } +} + +struct mmq_args { + const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst; + int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst; + int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst; + int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst; + bool use_stream_k; int64_t ncols_max; +}; + +template<ggml_type type> +static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) { + const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); + const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); + const size_t nbs_ids = mmq_x*sizeof(int); + const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq)); + return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); +} + +template <ggml_type type, int mmq_x> +static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int nsm = ggml_cuda_info().devices[id].nsm; + const int warp_size = ggml_cuda_info().devices[id].warp_size; + const int nwarps = mmq_get_nwarps_host(cc, warp_size); + const int mmq_y = get_mmq_y_host(cc); + + const dim3 block_dims(warp_size, nwarps, 1); + + const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps); + + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false>), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared); + + const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; + const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x; + const int ntzw = args.nchannels_y * args.nsamples_y; + const dim3 block_nums_xy_tiling(nty, ntx, ntzw); + + GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0); + GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0); + const int channel_ratio = args.nchannels_y / args.nchannels_x; + const int sample_ratio = args.nsamples_y / args.nsamples_x; + + if (!args.use_stream_k) { + if (args.nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); + } else { + constexpr bool need_check = true; + mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); + } + return; + } + + const dim3 block_nums_stream_k(nsm, 1, 1); + const bool fixup_needed = ntx*nty*ntzw % nsm != 0; + + ggml_cuda_pool & pool = ctx.pool(id); + ggml_cuda_pool_alloc<float> tmp_fixup(pool); + if (fixup_needed) { + tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y); + } + + if (args.nrows_x % mmq_y == 0) { + constexpr bool need_check = false; + mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); + + if (!fixup_needed) { + return; + } + + mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, + args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, + args.ncols_max); + } else { + constexpr bool need_check = true; + mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>> + (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, + args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + args.ncols_max); + + if (!fixup_needed) { + return; + } + + mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, + args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, + args.ncols_max); + } +} + +template <ggml_type type> +void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + const int warp_size = ggml_cuda_info().devices[id].warp_size; + const int nwarps = mmq_get_nwarps_host(cc, warp_size); + + const int mmq_x_max = get_mmq_x_max_host(cc); + const int mmq_y = get_mmq_y_host(cc); + + int mmq_x_best = 0; + int ntiles_x_best = INT_MAX; + + for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) { + const int granularity = mmq_get_granularity_host(mmq_x, cc); + + if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) { + continue; + } + + const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x; + + if (ntiles_x < ntiles_x_best) { + mmq_x_best = mmq_x; + ntiles_x_best = ntiles_x; + } + } + + switch (mmq_x_best) { + case 8: + launch_mul_mat_q<type, 8>(ctx, args, stream); + break; + case 16: + launch_mul_mat_q<type, 16>(ctx, args, stream); + break; + case 24: + launch_mul_mat_q<type, 24>(ctx, args, stream); + break; + case 32: + launch_mul_mat_q<type, 32>(ctx, args, stream); + break; + case 40: + launch_mul_mat_q<type, 40>(ctx, args, stream); + break; + case 48: + launch_mul_mat_q<type, 48>(ctx, args, stream); + break; + case 56: + launch_mul_mat_q<type, 56>(ctx, args, stream); + break; + case 64: + launch_mul_mat_q<type, 64>(ctx, args, stream); + break; + case 72: + launch_mul_mat_q<type, 72>(ctx, args, stream); + break; + case 80: + launch_mul_mat_q<type, 80>(ctx, args, stream); + break; + case 88: + launch_mul_mat_q<type, 88>(ctx, args, stream); + break; + case 96: + launch_mul_mat_q<type, 96>(ctx, args, stream); + break; + case 104: + launch_mul_mat_q<type, 104>(ctx, args, stream); + break; + case 112: + launch_mul_mat_q<type, 112>(ctx, args, stream); + break; + case 120: + launch_mul_mat_q<type, 120>(ctx, args, stream); + break; + case 128: + launch_mul_mat_q<type, 128>(ctx, args, stream); + break; + default: + fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best); + GGML_ABORT("fatal error"); + break; + } +} + +#define DECL_MMQ_CASE(type) \ + template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \ + +extern DECL_MMQ_CASE(GGML_TYPE_Q4_0); +extern DECL_MMQ_CASE(GGML_TYPE_Q4_1); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_0); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_1); +extern DECL_MMQ_CASE(GGML_TYPE_Q8_0); +extern DECL_MMQ_CASE(GGML_TYPE_MXFP4); +extern DECL_MMQ_CASE(GGML_TYPE_Q2_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q5_K); +extern DECL_MMQ_CASE(GGML_TYPE_Q6_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); +extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); + +// ------------------------------------------------------------------------------------------------------------------------- + +void ggml_cuda_mul_mat_q( + ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); + +void ggml_cuda_op_mul_mat_q( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream); + +bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts); diff --git a/llama.cpp/ggml/src/ggml-cuda/mmvf.cu b/llama.cpp/ggml/src/ggml-cuda/mmvf.cu new file mode 100644 index 0000000..d914720 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/mmvf.cu @@ -0,0 +1,862 @@ +#include "ggml.h" +#include "common.cuh" +#include "unary.cuh" +#include "mmvf.cuh" +#include "convert.cuh" + +template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false, bool is_multi_token_id = false> +static __global__ void mul_mat_vec_f( + const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, + const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, + const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const int ids_stride) { + const int row = blockIdx.x; + // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens) + const int channel_dst = blockIdx.y; + const int tid = threadIdx.x; + + int token_idx; + int channel_x; + int channel_y; + int sample_dst; + + if constexpr (is_multi_token_id) { + // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case + token_idx = blockIdx.z; + channel_x = ids[channel_dst + token_idx * ids_stride]; + channel_y = fastmodulo(channel_dst, nchannels_y); + sample_dst = 0; + } else { + token_idx = ids ? blockIdx.z : 0; + channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv((uint32_t) channel_dst, channel_ratio); + channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst; + sample_dst = ids ? 0 : blockIdx.z; + } + + const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio); + const int sample_y = sample_dst; + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; + y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; + dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; + if constexpr (is_multi_token_id) { + y += token_idx*stride_col_y2*2; + dst += token_idx*stride_col_dst; + } + + bool use_gate = false; + bool use_bias = false; + bool use_gate_bias = false; + ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU; + const T * gate_x = nullptr; + const float * x_bias = nullptr; + const float * gate_bias = nullptr; + + if constexpr (has_fusion) { + use_gate = fusion.gate != nullptr; + use_bias = fusion.x_bias != nullptr; + use_gate_bias = fusion.gate_bias != nullptr; + glu_op = fusion.glu_op; + + if (use_gate) { + gate_x = static_cast<const T *>(fusion.gate); + } + if (use_bias) { + x_bias = static_cast<const float *>(fusion.x_bias); + } + if (use_gate_bias) { + gate_bias = static_cast<const float *>(fusion.gate_bias); + use_gate_bias = use_gate; + } else { + use_gate_bias = false; + } + } + + if (use_gate) { + gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; + } + + const int channel_bias = ids ? channel_x : channel_dst; + + if constexpr (has_fusion) { + if (use_bias) { + x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; + } + if (use_gate_bias) { + gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; + } + } + + const float2 * y2 = (const float2 *) y; + + extern __shared__ char data_mmv[]; + float * buf_iw = (float *) data_mmv; + float * buf_iw_gate = nullptr; + if constexpr (has_fusion) { + buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float)); + } + + if (block_size > warp_size) { + if (tid < warp_size) { + buf_iw[tid] = 0.0f; + if constexpr (has_fusion) { + if (use_gate) { + buf_iw_gate[tid] = 0.0f; + } + } + } + __syncthreads(); + } + + float sumf[ncols_dst] = {0.0f}; + float sumf_gate[ncols_dst]; + if constexpr (has_fusion) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + sumf_gate[j] = 0.0f; + } + } + + if constexpr (std::is_same_v<T, float>) { + const float2 * x2 = (const float2 *) x; + const float2 * gate_x2 = nullptr; + if constexpr (has_fusion) { + if (use_gate) { + gate_x2 = (const float2 *) gate_x; + } + } + + for (int col2 = tid; col2 < ncols2; col2 += block_size) { + const float2 tmpx = x2[col2]; + float2 tmpx_gate = make_float2(0.0f, 0.0f); + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = gate_x2[col2]; + } + } + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x); + ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y); + } + } + } + } + } else if constexpr (std::is_same_v<T, half>) { + const half2 * x2 = (const half2 *) x; + const half2 * gate_x2 = nullptr; + if constexpr (has_fusion) { + if (use_gate) { + gate_x2 = (const half2 *) gate_x; + } + } + + if (std::is_same_v<type_acc, float>) { + for (int col2 = tid; col2 < ncols2; col2 += block_size) { + const float2 tmpx = __half22float2(x2[col2]); + float2 tmpx_gate = make_float2(0.0f, 0.0f); + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = __half22float2(gate_x2[col2]); + } + } +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x); + ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y); + } + } + } + } + } else { +#ifdef FP16_AVAILABLE + half2 sumh2[ncols_dst] = {{0.0f, 0.0f}}; + half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}}; + + for (int col2 = tid; col2 < ncols2; col2 += block_size) { + const half2 tmpx = x2[col2]; + half2 tmpx_gate = make_half2(0.0f, 0.0f); + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = gate_x2[col2]; + } + } +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y); + } + } + } + } + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]); + } + + if constexpr (has_fusion) { + if (use_gate) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]); + } + } + } +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE + } + } else if constexpr (std::is_same_v<T, nv_bfloat16>) { +//TODO: add support for ggml_cuda_mad for hip_bfloat162 +#if defined(GGML_USE_HIP) + const int * x2 = (const int *) x; + const int * gate_x2 = nullptr; + if constexpr (has_fusion) { + if (use_gate) { + gate_x2 = (const int *) gate_x; + } + } + for (int col2 = tid; col2 < ncols2; col2 += block_size) { + const int tmpx = x2[col2]; + int tmpx_gate = 0; + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = gate_x2[col2]; + } + } +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]); + const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]); + ggml_cuda_mad(sumf[j], tmpx0, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx1, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]); + const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]); + ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x); + ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y); + } + } + } + } +#else + const nv_bfloat162 * x2 = (const nv_bfloat162 *) x; + const nv_bfloat162 * gate_x2 = nullptr; + if constexpr (has_fusion) { + if (use_gate) { + gate_x2 = (const nv_bfloat162 *) gate_x; + } + } + for (int col2 = tid; col2 < ncols2; col2 += block_size) { + const nv_bfloat162 tmpx = x2[col2]; + nv_bfloat162 tmpx_gate; + if constexpr (has_fusion) { + if (use_gate) { + tmpx_gate = gate_x2[col2]; + } + } +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const float2 tmpy = y2[j*stride_col_y2 + col2]; + ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x); + ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y); + + if constexpr (has_fusion) { + if (use_gate) { + ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x); + ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y); + } + } + } + } +#endif + } else { + static_assert(std::is_same_v<T, void>, "unsupported type"); + } + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + sumf[j] = warp_reduce_sum<warp_size>(sumf[j]); + + if constexpr (has_fusion) { + if (use_gate) { + sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]); + } + } + + if (block_size > warp_size) { + buf_iw[tid/warp_size] = sumf[j]; + if constexpr (has_fusion) { + if (use_gate) { + buf_iw_gate[tid/warp_size] = sumf_gate[j]; + } + } + __syncthreads(); + if (tid < warp_size) { + sumf[j] = buf_iw[tid]; + sumf[j] = warp_reduce_sum<warp_size>(sumf[j]); + if constexpr (has_fusion) { + if (use_gate) { + sumf_gate[j] = buf_iw_gate[tid]; + sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]); + } + } + } + + if (j < ncols_dst) { + __syncthreads(); + } + } + } + + if (tid >= ncols_dst) { + return; + } + + float value = sumf[tid]; + + if constexpr (has_fusion) { + if (use_bias) { + value += x_bias[tid*stride_col_dst + row]; + } + + if (use_gate) { + float gate_value = sumf_gate[tid]; + if (use_gate_bias) { + gate_value += gate_bias[tid*stride_col_dst + row]; + } + switch (glu_op) { + case GGML_GLU_OP_SWIGLU: + value *= ggml_cuda_op_silu_single(gate_value); + break; + case GGML_GLU_OP_GEGLU: + value *= ggml_cuda_op_gelu_single(gate_value); + break; + case GGML_GLU_OP_SWIGLU_OAI: { + value = ggml_cuda_op_swiglu_oai_single(gate_value, value); + break; + } + default: + break; + } + } + } + + dst[tid*stride_col_dst + row] = value; + + if constexpr (!has_fusion) { + GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate); + } +} + +template<typename T, typename type_acc, int ncols_dst, int block_size, bool is_multi_token_id = false> +static void mul_mat_vec_f_switch_fusion( + const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const int64_t ncols, const uint3 nchannels_y, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) { + + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + if constexpr (ncols_dst == 1) { + if (has_fusion) { + mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>> + (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); + return; + } + } + + GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); + + mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>> + (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); + +} + +template <typename T, typename type_acc, int ncols_dst, bool is_multi_token_id = false> +void launch_mul_mat_vec_f_cuda( + const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const int64_t ncols, const int64_t nrows, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + const int64_t nsamples_or_ntokens, const int64_t ids_stride, cudaStream_t stream) { + GGML_ASSERT(ncols % 2 == 0); + GGML_ASSERT(stride_row % 2 == 0); + GGML_ASSERT(stride_col_y % 2 == 0); + GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); + GGML_ASSERT( nsamples_dst % nsamples_x == 0); + const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0); + const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); + const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); + + const int device = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[device].warp_size; + + int64_t block_size_best = warp_size; + int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size); + int64_t max_block_size = 256; + if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) { + max_block_size = 128; + } + for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) { + const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size); + if (niter < niter_best) { + niter_best = niter; + block_size_best = block_size; + } + } + + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + + const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0); + const dim3 block_nums(nrows, nchannels_dst, nsamples_or_ntokens); + const dim3 block_dims(block_size_best, 1, 1); + switch (block_size_best) { + case 32: { + mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32, is_multi_token_id> + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); + } break; + case 64: { + mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64, is_multi_token_id> + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); + } break; + case 96: { + mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96, is_multi_token_id> + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); + } break; + case 128: { + mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128, is_multi_token_id> + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); + } break; + case 160: { + mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160, is_multi_token_id> + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); + } break; + case 192: { + mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192, is_multi_token_id> + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); + } break; + case 224: { + mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224, is_multi_token_id> + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); + } break; + case 256: { + mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256, is_multi_token_id> + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +template <typename T, typename type_acc> +static void mul_mat_vec_f_cuda_switch_ncols_dst( + const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + const int64_t ids_stride, cudaStream_t stream) { + + const bool has_ids = ids != nullptr; + + if (has_ids && ncols_dst > 1) { + // Multi-token MUL_MAT_ID path only - single-token goes through regular path below + constexpr int c_ncols_dst = 1; + launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst, true> + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + ncols_dst, ids_stride, stream); + return; + } + + if (has_ids) { + // Single-token MUL_MAT_ID path + constexpr int c_ncols_dst = 1; + launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst> + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + ncols_dst, ids_stride, stream); + return; + } + + switch (ncols_dst) { + case 1: + launch_mul_mat_vec_f_cuda<T, type_acc, 1> + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); + break; + case 2: + launch_mul_mat_vec_f_cuda<T, type_acc, 2> + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); + break; + case 3: + launch_mul_mat_vec_f_cuda<T, type_acc, 3> + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); + break; + case 4: + launch_mul_mat_vec_f_cuda<T, type_acc, 4> + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); + break; + case 5: + launch_mul_mat_vec_f_cuda<T, type_acc, 5> + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); + break; + case 6: + launch_mul_mat_vec_f_cuda<T, type_acc, 6> + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); + break; + case 7: + launch_mul_mat_vec_f_cuda<T, type_acc, 7> + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); + break; + case 8: + launch_mul_mat_vec_f_cuda<T, type_acc, 8> + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + +template<typename T> +static void mul_mat_vec_f_cuda( + const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + const int64_t ids_stride, enum ggml_prec prec, cudaStream_t stream) { + + if constexpr(std::is_same_v<T, half>) { + if (prec == GGML_PREC_DEFAULT) { + mul_mat_vec_f_cuda_switch_ncols_dst<T, half> + (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + return; + } + } + mul_mat_vec_f_cuda_switch_ncols_dst<T, float> + (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); +} + +void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const ggml_cuda_mm_fusion_args_host * fusion) { + GGML_ASSERT( src1->type == GGML_TYPE_F32); + GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const size_t ts_src0 = ggml_type_size(src0->type); + const size_t ts_src1 = ggml_type_size(src1->type); + const size_t ts_dst = ggml_type_size(dst->type); + + GGML_ASSERT(!ids || ne12 <= MMVF_MAX_BATCH_SIZE); + GGML_ASSERT(ne13 == ne3); + + GGML_ASSERT( nb00 == ts_src0); + GGML_ASSERT( nb10 == ts_src1); + GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); + GGML_ASSERT( nb0 == ts_dst); + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; + + const float * src1_d = (const float *) src1->data; + const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; + float * dst_d = (float *) dst->data; + + ggml_cuda_mm_fusion_args_device fusion_local{}; + + if (fusion) { + GGML_ASSERT( !ids || dst->ne[2] == 1); + GGML_ASSERT( ids || dst->ne[1] == 1); + if (fusion->x_bias) { + GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]); + fusion_local.x_bias = fusion->x_bias->data; + } + if (fusion->gate) { + GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0)); + fusion_local.gate = fusion->gate->data; + } + if (fusion->gate_bias) { + GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]); + fusion_local.gate_bias = fusion->gate_bias->data; + } + fusion_local.glu_op = fusion->glu_op; + } + + const int64_t s01 = src0->nb[1] / ts_src0; + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s1 = dst->nb[1] / ts_dst; + const int64_t s02 = src0->nb[2] / ts_src0; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s2 = dst->nb[2] / ts_dst; + const int64_t s03 = src0->nb[3] / ts_src0; + const int64_t s13 = src1->nb[3] / ts_src1; + const int64_t s3 = dst->nb[3] / ts_dst; + + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: + const int64_t ncols_dst = ids ? ne2 : ne1; + const int64_t nchannels_y = ids ? ne11 : ne12; + const int64_t nchannels_dst = ids ? ne1 : ne2; + const int64_t stride_col_dst = ids ? s2 : s1; + const int64_t stride_col_y = ids ? s12 : s11; + const int64_t stride_channel_dst = ids ? s1 : s2; + const int64_t stride_channel_y = ids ? s11 : s12; + + const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + + switch (src0->type) { + case GGML_TYPE_F32: { + const float * src0_d = (const float *) src0->data; + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, + ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, + ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream()); + } break; + case GGML_TYPE_F16: { + const half * src0_d = (const half *) src0->data; + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, + ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, + ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream()); + } break; + case GGML_TYPE_BF16: { + const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, + ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, + ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream()); + } break; + default: + GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); + } +} + +void ggml_cuda_op_mul_mat_vec_f( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne10 = src1->ne[0]; + const int64_t ne0 = dst->ne[0]; + const int64_t row_diff = row_high - row_low; + + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; + + // ggml_cuda_op provides single, contiguous matrices + const int64_t stride_row = ne00; + const int64_t stride_col_y = ne10; + const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer + const int64_t nchannels_x = 1; + const int64_t nchannels_y = 1; + const int64_t nchannels_dst = 1; + const int64_t stride_channel_x = 0; + const int64_t stride_channel_y = 0; + const int64_t stride_channel_dst = 0; + const int64_t nsamples_x = 1; + const int64_t nsamples_dst = 1; + const int64_t stride_sample_x = 0; + const int64_t stride_sample_y = 0; + const int64_t stride_sample_dst = 0; + + ggml_cuda_mm_fusion_args_device empty{}; + switch (src0->type) { + case GGML_TYPE_F32: { + const float * src0_d = (const float *) src0_dd_i; + mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream); + } break; + case GGML_TYPE_F16: { + const half * src0_d = (const half *) src0_dd_i; + mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream); + } break; + case GGML_TYPE_BF16: { + const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; + mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream); + } break; + default: + GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); + } + + GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_ncols, src1_padded_row_size); +} + +bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11) { + if (src0_ne[0] % 2 != 0) { + return false; + } + + const size_t ts = ggml_type_size(type); + if (src0_nb[0] != ts) { + return false; + } + + // Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash: + for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { + if (src0_nb[i] % (2*ts) != 0) { + return false; + } + } + + switch (type) { + case GGML_TYPE_F32: + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + if (ampere_mma_available(cc)) { + return ne11 <= 3; + } + if (cc >= GGML_CUDA_CC_TURING) { + return ne11 <= 4; + } + return ne11 <= 3; + } else if (GGML_CUDA_CC_IS_AMD(cc)) { + if (fp32_mma_hardware_available(cc)) { + return ne11 <= 3; + } + return ne11 <= 8; + } + return ne11 <= 8; + case GGML_TYPE_F16: + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1); + if (ampere_mma_available(cc)) { + return src0_small && ne11 == 1; + } + if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { + return src0_small && ne11 <= 4; + } + if (fp16_mma_hardware_available(cc)) { + return src0_small && ne11 <= 3; + } + return ne11 <= 8; + } else if (GGML_CUDA_CC_IS_AMD(cc)) { + if (fp16_mma_hardware_available(cc)) { + if (GGML_CUDA_CC_IS_RDNA3(cc)) { + return ne11 <= 3; + } + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return ne11 <= 5; + } + return ne11 <= 2; + } + return ne11 <= 8; + } + return ne11 <= 8; + case GGML_TYPE_BF16: + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1); + if (ampere_mma_available(cc)) { + return src0_small && ne11 == 1; + } + if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { + return src0_small && ne11 <= 4; + } + if (bf16_mma_hardware_available(cc)) { + return src0_small && ne11 <= 3; + } + return ne11 <= 8; + } else if (GGML_CUDA_CC_IS_AMD(cc)) { + if (bf16_mma_hardware_available(cc)) { + return ne11 <= 3; + } + return ne11 <= 8; + } + return ne11 <= 8; + default: + return false; + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/mmvf.cuh b/llama.cpp/ggml/src/ggml-cuda/mmvf.cuh new file mode 100644 index 0000000..a50f7c0 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/mmvf.cuh @@ -0,0 +1,14 @@ +#include "common.cuh" + +#define MMVF_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVF kernels. + +void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const ggml_cuda_mm_fusion_args_host * fusion = nullptr); + +void ggml_cuda_op_mul_mat_vec_f( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream); + +bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11); diff --git a/llama.cpp/ggml/src/ggml-cuda/mmvq.cu b/llama.cpp/ggml/src/ggml-cuda/mmvq.cu new file mode 100644 index 0000000..ce25ccf --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/mmvq.cu @@ -0,0 +1,767 @@ +#include "mmvq.cuh" +#include "quantize.cuh" +#include "unary.cuh" +#include "vecdotq.cuh" + +#include <cstdint> + +typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); + +static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1; + case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1; + case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1; + case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1; + case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1; + case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1; + case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1; + case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1; + case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1; + case GGML_TYPE_Q5_K: return vec_dot_q5_K_q8_1; + case GGML_TYPE_Q6_K: return vec_dot_q6_K_q8_1; + case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1; + case GGML_TYPE_IQ2_XS: return vec_dot_iq2_xs_q8_1; + case GGML_TYPE_IQ2_S: return vec_dot_iq2_s_q8_1; + case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1; + case GGML_TYPE_IQ1_S: return vec_dot_iq1_s_q8_1; + case GGML_TYPE_IQ1_M: return vec_dot_iq1_m_q8_1; + case GGML_TYPE_IQ4_NL: return vec_dot_iq4_nl_q8_1; + case GGML_TYPE_IQ4_XS: return vec_dot_iq4_xs_q8_1; + case GGML_TYPE_IQ3_S: return vec_dot_iq3_s_q8_1; + default: return nullptr; + } +} + +static constexpr __device__ int get_vdr_mmvq(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; + case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; + case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ; + case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ; + case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ; + case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ; + case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ; + case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ; + case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ; + case GGML_TYPE_Q5_K: return VDR_Q5_K_Q8_1_MMVQ; + case GGML_TYPE_Q6_K: return VDR_Q6_K_Q8_1_MMVQ; + case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ; + case GGML_TYPE_IQ2_XS: return VDR_IQ2_XS_Q8_1_MMVQ; + case GGML_TYPE_IQ2_S: return VDR_IQ2_S_Q8_1_MMVQ; + case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ; + case GGML_TYPE_IQ3_S: return VDR_IQ3_S_Q8_1_MMVQ; + case GGML_TYPE_IQ4_NL: return VDR_IQ4_NL_Q8_1_MMVQ; + case GGML_TYPE_IQ4_XS: return VDR_IQ4_XS_Q8_1_MMVQ; + default: return 1; + } +} + +enum mmvq_parameter_table_id { + MMVQ_PARAMETERS_GENERIC = 0, + MMVQ_PARAMETERS_GCN, + MMVQ_PARAMETERS_RDNA2 +}; + +static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { +#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4) + return MMVQ_PARAMETERS_RDNA2; +#elif defined(GCN) || defined(CDNA) + return MMVQ_PARAMETERS_GCN; +#else + return MMVQ_PARAMETERS_GENERIC; +#endif +} + +static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { + if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { + return MMVQ_PARAMETERS_RDNA2; + } + if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { + return MMVQ_PARAMETERS_GCN; + } + return MMVQ_PARAMETERS_GENERIC; +} + +static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) { + if (table_id == MMVQ_PARAMETERS_GENERIC) { + switch (ncols_dst) { + case 1: + case 2: + case 3: + case 4: + return 4; + case 5: + case 6: + case 7: + case 8: + return 2; + default: + return 1; + } + } else if (table_id == MMVQ_PARAMETERS_GCN) { + switch (ncols_dst) { + case 1: + case 2: + case 3: + case 4: + return 2; + case 5: + case 6: + case 7: + case 8: + default: + return 1; + } + } + return 1; +} + +static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) { + if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { + switch (ncols_dst) { + case 1: + return 1; + case 2: + case 3: + case 4: + case 5: + case 6: + case 7: + case 8: + return 2; + default: + return 1; + } + } + return 1; +} + +template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false> +__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mul_mat_vec_q( + const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, + const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, + const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, + const uint32_t ids_stride) { + + constexpr int qk = ggml_cuda_type_traits<type>::qk; + constexpr int qi = ggml_cuda_type_traits<type>::qi; + constexpr int vdr = get_vdr_mmvq(type); + constexpr mmvq_parameter_table_id table_id = get_device_table_id(); + constexpr int nwarps = calc_nwarps(ncols_dst, table_id); + constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); + + const int tid = warp_size*threadIdx.y + threadIdx.x; + const int row0 = rows_per_cuda_block*blockIdx.x; + const int blocks_per_row_x = ncols_x / qk; + constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; + + const uint32_t channel_dst = blockIdx.y; + + uint32_t token_idx = 0; + uint32_t channel_x; + uint32_t channel_y; + uint32_t sample_dst; + + if constexpr (is_multi_token_id) { + // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case + token_idx = blockIdx.z; + channel_x = ids[channel_dst + token_idx * ids_stride]; + channel_y = fastmodulo(channel_dst, nchannels_y); + sample_dst = 0; + } else { + channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); + channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; + sample_dst = blockIdx.z; + } + + const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); + const uint32_t sample_y = sample_dst; + + bool use_gate = false; + bool use_bias = false; + bool use_gate_bias = false; + const void * vgate = nullptr; + const float * x_bias = nullptr; + const float * gate_bias = nullptr; + ggml_glu_op active_glu; + + if constexpr (has_fusion) { + use_gate = fusion.gate != nullptr; + use_bias = fusion.x_bias != nullptr; + use_gate_bias = fusion.gate_bias != nullptr && use_gate; + vgate = fusion.gate; + x_bias = (const float *) fusion.x_bias; + gate_bias = (const float *) fusion.gate_bias; + active_glu = fusion.glu_op; + } + + + float x_biases[ncols_dst] = { 0.0f }; + float gate_biases[ncols_dst] = { 0.0f }; + if constexpr (has_fusion) { + const uint32_t channel_bias = ids ? channel_x : channel_dst; + if (use_bias) { + x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; + // 1. Hide latency by prefetching bias and gate here + // 2. load only on threads that won't die after partial sum calculation + if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 && + (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + x_biases[j] = x_bias[j * stride_col_dst + threadIdx.x]; + } + } + } + if (use_gate_bias) { + gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; + if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 && + (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + gate_biases[j] = gate_bias[j * stride_col_dst + threadIdx.x]; + } + } + } + } + + // partial sum for each thread + float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; + float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}}; + + const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; + if constexpr (is_multi_token_id) { + y += token_idx*stride_col_y; + } + const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; + + for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { + const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx + + // x block quant index when casting the quants to int + const int kqs = vdr * (tid % (qi/vdr)); + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { + tmp[j][i] += vec_dot_q_cuda( + vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs); + if constexpr (has_fusion) { + if (use_gate) { + tmp_gate[j][i] += vec_dot_q_cuda( + vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs); + } + } + } + } + } + + __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; + __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; + if constexpr (!has_fusion) { + (void) tmp_shared_gate; + } else if (!use_gate) { + (void) tmp_shared_gate; + } + + if (threadIdx.y > 0) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { + tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i]; + if constexpr (has_fusion) { + if (use_gate) { + tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i]; + } + } + } + } + } + __syncthreads(); + if (threadIdx.y > 0) { + return; + } + + dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0; + + if constexpr (is_multi_token_id) { + dst += token_idx*stride_col_dst; + } + + // sum up partial sums and write back result +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { +#pragma unroll + for (int i = 0; i < rows_per_cuda_block; ++i) { +#pragma unroll + for (int l = 0; l < nwarps-1; ++l) { + tmp[j][i] += tmp_shared[l][j][i][threadIdx.x]; + if constexpr (has_fusion) { + if (use_gate) { + tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x]; + } + } + } + tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]); + if constexpr (has_fusion) { + if (use_gate) { + tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]); + } + } + } + + if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) { + float result = tmp[j][threadIdx.x]; + if constexpr (has_fusion) { + if (use_bias) { + result += x_biases[j]; + } + if (use_gate) { + float gate_value = tmp_gate[j][threadIdx.x]; + if (use_gate_bias) { + gate_value += gate_biases[j]; + } + switch (active_glu) { + case GGML_GLU_OP_SWIGLU: + result *= ggml_cuda_op_silu_single(gate_value); + break; + case GGML_GLU_OP_GEGLU: + result *= ggml_cuda_op_gelu_single(gate_value); + break; + case GGML_GLU_OP_SWIGLU_OAI: { + result = ggml_cuda_op_swiglu_oai_single(gate_value, result); + break; + } + default: + result = result * gate_value; + break; + } + } + } + dst[j*stride_col_dst + threadIdx.x] = result; + } + } + + if constexpr (!has_fusion) { + GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate); + } +} + +static std::pair<dim3, dim3> calc_launch_params( + const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, + const int warp_size, const mmvq_parameter_table_id table_id) { + const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); + const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); + const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1); + return {block_nums, block_dims}; +} + +template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false> +static void mul_mat_vec_q_switch_fusion( + const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, + const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, + const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, + const uint32_t ids_stride, cudaStream_t stream) { + + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + if constexpr (c_ncols_dst == 1) { + if (has_fusion) { + mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>> + (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); + return; + } + } + + GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); + + mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>> + (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); +} + +template <ggml_type type> +static void mul_mat_vec_q_switch_ncols_dst( + const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const int ncols_x, const int nrows_x, const int ncols_dst, + const int stride_row_x, const int stride_col_y, const int stride_col_dst, + const int nchannels_x, const int nchannels_y, const int nchannels_dst, + const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const int ids_stride, cudaStream_t stream) { + + GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); + GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); + + const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0); + const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); + const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); + + const int device = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[device].warp_size; + const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); + + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + const bool has_ids = ids != nullptr; + + if (has_ids && ncols_dst > 1) { + // Multi-token MUL_MAT_ID path only - single-token goes through regular path below + constexpr int c_ncols_dst = 1; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + return; + } + + switch (ncols_dst) { + case 1: { + constexpr int c_ncols_dst = 1; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 2: { + constexpr int c_ncols_dst = 2; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 3: { + constexpr int c_ncols_dst = 3; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 4: { + constexpr int c_ncols_dst = 4; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 5: { + constexpr int c_ncols_dst = 5; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 6: { + constexpr int c_ncols_dst = 6; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 7: { + constexpr int c_ncols_dst = 7; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + case 8: { + constexpr int c_ncols_dst = 8; + std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } break; + default: + GGML_ABORT("fatal error"); + break; + } + + GGML_UNUSED(has_fusion); +} +static void mul_mat_vec_q_switch_type( + const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, + const int ncols_x, const int nrows_x, const int ncols_dst, + const int stride_row_x, const int stride_col_y, const int stride_col_dst, + const int nchannels_x, const int nchannels_y, const int nchannels_dst, + const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const int ids_stride, cudaStream_t stream) { + switch (type_x) { + case GGML_TYPE_Q4_0: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q4_1: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q5_0: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q5_1: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q8_0: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_MXFP4: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q2_K: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q3_K: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q4_K: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q5_K: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_Q6_K: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ2_XXS: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ2_S: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ1_S: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ1_M: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ4_NL: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ4_XS: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + +void ggml_cuda_mul_mat_vec_q( + ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, + const ggml_cuda_mm_fusion_args_host * fusion) { + GGML_ASSERT( src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID. + + GGML_TENSOR_BINARY_OP_LOCALS; + + cudaStream_t stream = ctx.stream(); + + const size_t ts_src0 = ggml_type_size(src0->type); + const size_t ts_src1 = ggml_type_size(src1->type); + const size_t ts_dst = ggml_type_size(dst->type); + + GGML_ASSERT( nb00 == ts_src0); + GGML_ASSERT( nb10 == ts_src1); + GGML_ASSERT( nb0 == ts_dst); + GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); + + GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE); + + const float * src1_d = (const float *) src1->data; + const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; + float * dst_d = (float *) dst->data; + + ggml_cuda_mm_fusion_args_device fusion_local{}; + + if (fusion) { + GGML_ASSERT( !ids || dst->ne[2] == 1); + GGML_ASSERT( ids || dst->ne[1] == 1); + + if (fusion->x_bias) { + GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]); + fusion_local.x_bias = fusion->x_bias->data; + } + if (fusion->gate) { + GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0)); + fusion_local.gate = fusion->gate->data; + } + if (fusion->gate_bias) { + GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32); + GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]); + GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]); + fusion_local.gate_bias = fusion->gate_bias->data; + } + fusion_local.glu_op = fusion->glu_op; + } + + // If src0 is a temporary compute buffer, clear any potential padding. + if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { + const size_t size_data = ggml_nbytes(src0); + const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0); + if (size_alloc > size_data) { + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); + GGML_ASSERT(!src0->view_src); + CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream)); + } + } + + const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING); + ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1); + { + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s13 = src1->nb[3] / ts_src1; + quantize_row_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); + } + + const int64_t s01 = src0->nb[1] / ts_src0; + const int64_t s11 = ne10_padded / QK8_1; + const int64_t s1 = dst->nb[1] / ts_dst; + const int64_t s02 = src0->nb[2] / ts_src0; + const int64_t s2 = dst->nb[2] / ts_dst; + const int64_t s03 = src0->nb[3] / ts_src0; + const int64_t s3 = dst->nb[3] / ts_dst; + + const int64_t s12 = ne11*s11; + const int64_t s13 = ne12*s12; + + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: + const int64_t ncols_dst = ids ? ne2 : ne1; + const int64_t nchannels_y = ids ? ne11 : ne12; + const int64_t nchannels_dst = ids ? ne1 : ne2; + const int64_t stride_col_dst = ids ? s2 : s1; + const int64_t stride_col_y = ids ? s12 : s11; + const int64_t stride_channel_dst = ids ? s1 : s2; + const int64_t stride_channel_y = ids ? s11 : s12; + + const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + + mul_mat_vec_q_switch_type( + src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00, + ne01, ncols_dst, s01, stride_col_y, stride_col_dst, + ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, + ne03, ne3, s03, s13, s3, ids_stride, stream); +} + +void ggml_cuda_op_mul_mat_vec_q( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + const int64_t ne10 = src1->ne[0]; + GGML_ASSERT(ne10 % QK8_1 == 0); + + const int64_t ne0 = dst->ne[0]; + + int id = ggml_cuda_get_device(); + + // the main device has a larger memory buffer to hold the results from all GPUs + // nrows_dst == nrows of the matrix that the kernel writes into + const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff; + + const int stride_row_x = ne00 / ggml_blck_size(src0->type); + const int stride_col_y = src1_padded_row_size / QK8_1; + + ggml_cuda_mm_fusion_args_device fusion_local{}; + mul_mat_vec_q_switch_type( + src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream); + + GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/mmvq.cuh b/llama.cpp/ggml/src/ggml-cuda/mmvq.cuh new file mode 100644 index 0000000..4bb10cf --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/mmvq.cuh @@ -0,0 +1,12 @@ +#include "common.cuh" + +#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. + +void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr); + +void ggml_cuda_op_mul_mat_vec_q( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream); diff --git a/llama.cpp/ggml/src/ggml-cuda/norm.cu b/llama.cpp/ggml/src/ggml-cuda/norm.cu new file mode 100644 index 0000000..ef98f67 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/norm.cu @@ -0,0 +1,672 @@ +#include "norm.cuh" +#include <cstdint> + +template <int block_size> +static __global__ void norm_f32( + const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps) { + const int nrows = gridDim.x; + const int nchannels = gridDim.y; + + const int row = blockIdx.x; + const int channel = blockIdx.y; + const int sample = blockIdx.z; + const int tid = threadIdx.x; + + x += sample*stride_sample + channel*stride_channel + row*stride_row; + dst += ((sample*nchannels + channel)*nrows + row)*ncols; + + float2 mean_var = make_float2(0.0f, 0.0f); + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[col]; + mean_var.x += xi; + mean_var.y += xi * xi; + } + + // sum up partial sums + extern __shared__ float2 s_sum2[]; + mean_var = block_reduce<block_reduce_method::SUM, block_size>(mean_var, s_sum2); + + const float mean = mean_var.x / ncols; + const float var = mean_var.y / ncols - mean * mean; + const float inv_std = rsqrtf(var + eps); + + for (int col = tid; col < ncols; col += block_size) { + dst[col] = (x[col] - mean) * inv_std; + } +} + +template <int block_size> +static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) { + // blockIdx.x: num_groups idx + // threadIdx.x: block_size idx + const int start = blockIdx.x*group_size + threadIdx.x; + const int end = min(blockIdx.x*group_size + group_size, ne_elements); + + float tmp = 0.0f; // partial sum for thread in warp + + for (int j = start; j < end; j += block_size) { + tmp += x[j]; + } + + extern __shared__ float s_sum[]; + tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum); + + const float mean = tmp / group_size; + tmp = 0.0f; + + for (int j = start; j < end; j += block_size) { + const float xi = x[j] - mean; + dst[j] = xi; + tmp += xi * xi; + } + + tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum); + + const float variance = tmp / group_size; + const float scale = rsqrtf(variance + eps); + for (int j = start; j < end; j += block_size) { + dst[j] *= scale; + } +} + +template <int block_size, bool do_multiply = false, bool do_add = false> +static __global__ void rms_norm_f32(const float * x, + float * dst, + const int ncols, + const int64_t stride_row, + const int64_t stride_channel, + const int64_t stride_sample, + const float eps, + const float * mul = nullptr, + const int64_t mul_stride_row = 0, + const int64_t mul_stride_channel = 0, + const int64_t mul_stride_sample = 0, + const uint3 mul_ncols_packed = make_uint3(0, 0, 0), + const uint3 mul_nrows_packed = make_uint3(0, 0, 0), + const uint3 mul_nchannels_packed = make_uint3(0, 0, 0), + const uint3 mul_nsamples_packed = make_uint3(0, 0, 0), + const float * add = nullptr, + const int64_t add_stride_row = 0, + const int64_t add_stride_channel = 0, + const int64_t add_stride_sample = 0, + const uint3 add_ncols_packed = make_uint3(0, 0, 0), + const uint3 add_nrows_packed = make_uint3(0, 0, 0), + const uint3 add_nchannels_packed = make_uint3(0, 0, 0), + const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) { + const int nrows = gridDim.x; + const int nchannels = gridDim.y; + + const int row = blockIdx.x; + const int channel = blockIdx.y; + const int sample = blockIdx.z; + const int tid = threadIdx.x; + + static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying"); + + x += sample*stride_sample + channel*stride_channel + row*stride_row; + dst += ((sample*nchannels + channel)*nrows + row)*ncols; + + if constexpr (do_multiply) { + const uint32_t mul_row = fastmodulo(row, mul_nrows_packed); + const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed); + const uint32_t mul_sample = fastmodulo(sample, mul_nsamples_packed); + mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row; + } + + if constexpr (do_add) { + const int add_row = fastmodulo(row, add_nrows_packed); + const int add_channel = fastmodulo(channel, add_nchannels_packed); + const int add_sample = fastmodulo(sample, add_nsamples_packed); + add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row; + } + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[col]; + tmp += xi * xi; + } + + // sum up partial sums + extern __shared__ float s_sum[]; + tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum); + + const float mean = tmp / ncols; + const float scale = rsqrtf(mean + eps); + + for (int col = tid; col < ncols; col += block_size) { + if constexpr (do_multiply && do_add) { + const int mul_col = fastmodulo(col, mul_ncols_packed); + const int add_col = fastmodulo(col, add_ncols_packed); + dst[col] = scale * x[col] * mul[mul_col] + add[add_col]; + } else if constexpr (do_multiply) { + const int mul_col = fastmodulo(col, mul_ncols_packed); + dst[col] = scale * x[col] * mul[mul_col]; + } else { + dst[col] = scale * x[col]; + } + } +} + +template <int block_size> +static __global__ void rms_norm_back_f32( + const float * grad, const float * xf, float * dst, const int ncols, const float eps) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + grad += int64_t(row)*ncols; + xf += int64_t(row)*ncols; + dst += int64_t(row)*ncols; + + float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass + float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs + + for (int col = tid; col < ncols; col += block_size) { + const float xfi = xf[col]; + sum_xx += xfi * xfi; + sum_xg += xfi * grad[col]; + } + + // sum up partial sums + sum_xx = warp_reduce_sum(sum_xx); + sum_xg = warp_reduce_sum(sum_xg); + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); + __shared__ float s_sum_xx[32]; + __shared__ float s_sum_xg[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum_xx[warp_id] = sum_xx; + s_sum_xg[warp_id] = sum_xg; + } + __syncthreads(); + + sum_xx = s_sum_xx[lane_id]; + sum_xx = warp_reduce_sum(sum_xx); + + sum_xg = s_sum_xg[lane_id]; + sum_xg = warp_reduce_sum(sum_xg); + } + + const float mean_eps = sum_xx / ncols + eps; + const float sum_eps = sum_xx + ncols*eps; + + const float scale_grad = rsqrtf(mean_eps); + const float scale_x = -scale_grad * sum_xg/sum_eps; + + for (int col = tid; col < ncols; col += block_size) { + dst[col] = scale_grad*grad[col] + scale_x*xf[col]; + } +} + +// template <int block_size> +// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) { +// const int row = blockIdx.x*blockDim.y + threadIdx.y; +// const int tid = threadIdx.x; + +// float tmp = 0.0f; // partial sum for thread in warp + +// for (int col = tid; col < ncols; col += block_size) { +// const float xi = x[row*ncols + col]; +// tmp += xi * xi; +// } + +// // sum up partial sums +// tmp = warp_reduce_sum(tmp); +// if (block_size > WARP_SIZE) { +// __shared__ float s_sum[32]; +// int warp_id = threadIdx.x / WARP_SIZE; +// int lane_id = threadIdx.x % WARP_SIZE; +// if (lane_id == 0) { +// s_sum[warp_id] = tmp; +// } +// __syncthreads(); +// tmp = s_sum[lane_id]; +// tmp = warp_reduce_sum(tmp); +// } + +// // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html +// const float scale = rsqrtf(fmaxf(tmp, eps * eps)); + +// for (int col = tid; col < ncols; col += block_size) { +// dst[row*ncols + col] = scale * x[row*ncols + col]; +// } +// } + +template <int block_size> +static __global__ void l2_norm_f32( + const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps) { + const int nrows = gridDim.x; + const int nchannels = gridDim.y; + + const int row = blockIdx.x; + const int channel = blockIdx.y; + const int sample = blockIdx.z; + const int tid = threadIdx.x; + + x += sample*stride_sample + channel*stride_channel + row*stride_row; + dst += ((sample*nchannels + channel)*nrows + row)*ncols; + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[col]; + tmp += xi * xi; + } + + // sum up partial sums + extern __shared__ float s_sum[]; + tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum); + + // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html + const float scale = rsqrtf(fmaxf(tmp, eps * eps)); + + for (int col = tid; col < ncols; col += block_size) { + dst[col] = scale * x[col]; + } +} + +static void norm_f32_cuda( + const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } else { + const dim3 block_dims(1024, 1, 1); + norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } +} + +static void group_norm_f32_cuda( + const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) { + if (group_size < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps); + } else { + const dim3 block_dims(1024, 1, 1); + group_norm_f32<1024><<<num_groups, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps); + } +} + +static void rms_norm_f32_cuda( + const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); + if (ncols < 1024) { + const dim3 block_dims(256, 1, 1); + rms_norm_f32<256, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_f32<1024, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } +} + +static void rms_norm_mul_f32_cuda(const float * x, + const float * mul, + const float * add, + float * dst, + const int ncols, + const int nrows, + const int nchannels, + const int nsamples, + const int64_t stride_row, + const int64_t stride_channel, + const int64_t stride_sample, + const int64_t mul_stride_row, + const int64_t mul_stride_channel, + const int64_t mul_stride_sample, + const uint32_t mul_ncols, + const uint32_t mul_nrows, + const uint32_t mul_nchannels, + const uint32_t mul_nsamples, + const int64_t add_stride_row, + const int64_t add_stride_channel, + const int64_t add_stride_sample, + const uint32_t add_ncols, + const uint32_t add_nrows, + const uint32_t add_nchannels, + const uint32_t add_nsamples, + const float eps, + cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); + if (mul == nullptr) { + rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream); + return; + } + if (add == nullptr) { + const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); + const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows); + const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels); + const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); + if (ncols < 1024) { + const dim3 block_dims(256, 1, 1); + rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); + } + } else { + const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); + const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows); + const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels); + const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); + + const uint3 add_ncols_packed = init_fastdiv_values(add_ncols); + const uint3 add_nrows_packed = init_fastdiv_values(add_nrows); + const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels); + const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples); + if (ncols < 1024) { + const dim3 block_dims(256, 1, 1); + rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, + add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, + add_nchannels_packed, add_nsamples_packed); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, + add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, + add_nchannels_packed, add_nsamples_packed); + } + } +} + +static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + rms_norm_back_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_back_f32<1024><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps); + } +} + +static void l2_norm_f32_cuda( + const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } else { + const dim3 block_dims(1024, 1, 1); + l2_norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } +} + +void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_UNARY_OP_LOCALS; + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); + + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(nb00 == ts0); + const int64_t s01 = nb01 / ts0; + const int64_t s02 = nb02 / ts0; + const int64_t s03 = nb03 / ts0; + + norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream); +} + +void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + int num_groups = dst->op_params[0]; + + float eps; + memcpy(&eps, dst->op_params + 1, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); + + int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); + group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream); +} + +void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_UNARY_OP_LOCALS; + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); + + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(nb00 == ts0); + const int64_t s01 = nb01 / ts0; + const int64_t s02 = nb02 / ts0; + const int64_t s03 = nb03 / ts0; + + rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream); +} + +void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) { + const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0]; + float eps = 0.0f; + + memcpy(&eps, dst->op_params, sizeof(float)); + + const float * src0_d = (const float *) rms_norm_src->data; + const float * mul_d = nullptr; + const ggml_tensor * mul_src = nullptr; + + if (mul_tensor->src[0] == dst) { + mul_d = (float *) mul_tensor->src[1]->data; + mul_src = mul_tensor->src[1]; + } else if(mul_tensor->src[1] == dst) { + mul_d = (float *) mul_tensor->src[0]->data; + mul_src = mul_tensor->src[0]; + } else { + GGML_ASSERT(false); + } + + float * dst_d = (float *) mul_tensor->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32); + GGML_ASSERT(eps >= 0.0f); + + const int64_t ne00 = rms_norm_src->ne[0]; + const int64_t ne01 = rms_norm_src->ne[1]; + const int64_t ne02 = rms_norm_src->ne[2]; + const int64_t ne03 = rms_norm_src->ne[3]; + + const size_t ts0 = ggml_type_size(rms_norm_src->type); + GGML_ASSERT(rms_norm_src->nb[0] == ts0); + const int64_t s01 = rms_norm_src->nb[1] / ts0; + const int64_t s02 = rms_norm_src->nb[2] / ts0; + const int64_t s03 = rms_norm_src->nb[3] / ts0; + + const size_t ts_mul = ggml_type_size(mul_src->type); + GGML_ASSERT(mul_src->nb[0] == ts_mul); + const int64_t mul_s01 = mul_src->nb[1] / ts_mul; + const int64_t mul_s02 = mul_src->nb[2] / ts_mul; + const int64_t mul_s03 = mul_src->nb[3] / ts_mul; + + const int mul_ncols = mul_src->ne[0]; + const int mul_nrows = mul_src->ne[1]; + const int mul_nchannels = mul_src->ne[2]; + const int mul_nsamples = mul_src->ne[3]; + + rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d, + ne00, ne01, ne02, ne03, + /*s00*/ s01, s02, s03, + /*mul_s00*/ mul_s01, mul_s02, mul_s03, + mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, + /*add_s00*/ 0, 0, 0, + 0, 0, 0, 0, + eps, stream); +} + +void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx, + ggml_tensor * dst, + ggml_tensor * mul_tensor, + ggml_tensor * add_tensor) { + const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0]; + float eps = 0.0f; + + memcpy(&eps, dst->op_params, sizeof(float)); + + const float * src0_d = (const float *) rms_norm_src->data; + const float * mul_d = nullptr; + const ggml_tensor * mul_src = nullptr; + + if (mul_tensor->src[0] == dst) { + mul_d = (float *) mul_tensor->src[1]->data; + mul_src = mul_tensor->src[1]; + } else if (mul_tensor->src[1] == dst) { + mul_d = (float *) mul_tensor->src[0]->data; + mul_src = mul_tensor->src[0]; + } else { + GGML_ASSERT(false); + } + + const float * add_d = nullptr; + const ggml_tensor * add_src = nullptr; + + if (add_tensor->src[0] == mul_tensor) { + add_d = (float *) add_tensor->src[1]->data; + add_src = add_tensor->src[1]; + } else if (add_tensor->src[1] == mul_tensor) { + add_d = (float *) add_tensor->src[0]->data; + add_src = add_tensor->src[0]; + } else { + GGML_ASSERT(false); + } + + float * dst_d = (float *) add_tensor->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32); + GGML_ASSERT(add_tensor->type == GGML_TYPE_F32); + GGML_ASSERT(eps >= 0.0f); + + const int64_t ne00 = rms_norm_src->ne[0]; + const int64_t ne01 = rms_norm_src->ne[1]; + const int64_t ne02 = rms_norm_src->ne[2]; + const int64_t ne03 = rms_norm_src->ne[3]; + + const size_t ts0 = ggml_type_size(rms_norm_src->type); + GGML_ASSERT(rms_norm_src->nb[0] == ts0); + const int64_t s01 = rms_norm_src->nb[1] / ts0; + const int64_t s02 = rms_norm_src->nb[2] / ts0; + const int64_t s03 = rms_norm_src->nb[3] / ts0; + + const size_t ts_mul = ggml_type_size(mul_src->type); + GGML_ASSERT(mul_src->nb[0] == ts_mul); + const int64_t mul_s01 = mul_src->nb[1] / ts_mul; + const int64_t mul_s02 = mul_src->nb[2] / ts_mul; + const int64_t mul_s03 = mul_src->nb[3] / ts_mul; + + const int mul_ncols = mul_src->ne[0]; + const int mul_nrows = mul_src->ne[1]; + const int mul_nchannels = mul_src->ne[2]; + const int mul_nsamples = mul_src->ne[3]; + + const size_t ts_add = ggml_type_size(add_src->type); + GGML_ASSERT(add_src->nb[0] == ts_add); + const int64_t add_s01 = add_src->nb[1] / ts_add; + const int64_t add_s02 = add_src->nb[2] / ts_add; + const int64_t add_s03 = add_src->nb[3] / ts_add; + + const int add_ncols = add_src->ne[0]; + const int add_nrows = add_src->ne[1]; + const int add_nchannels = add_src->ne[2]; + const int add_nsamples = add_src->ne[3]; + + rms_norm_mul_f32_cuda(src0_d, mul_d,add_d,dst_d, + ne00,ne01, ne02, ne03, + /*s00*/ s01, s02, s03, + /*mul_s00*/ mul_s01, mul_s02, mul_s03, + mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, + /*add_s00*/ add_s01, add_s02, add_s03, + add_ncols, add_nrows, add_nchannels, add_nsamples, + eps, stream); +} + +void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * grad = dst->src[0]; // gradients + const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass + + const float * grad_d = (const float *) grad->data; + const float * src0f_d = (const float *) src0f->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(grad)); + + GGML_ASSERT( grad->type == GGML_TYPE_F32); + GGML_ASSERT(src0f->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0f->ne[0]; + const int64_t nrows = ggml_nrows(src0f); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); + + rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream); +} + +void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_UNARY_OP_LOCALS; + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); + + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(nb00 == ts0); + const int64_t s01 = nb01 / ts0; + const int64_t s02 = nb02 / ts0; + const int64_t s03 = nb03 / ts0; + + l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/norm.cuh b/llama.cpp/ggml/src/ggml-cuda/norm.cuh new file mode 100644 index 0000000..a74f637 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/norm.cuh @@ -0,0 +1,18 @@ +#include "common.cuh" + +void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor); + +void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx, + ggml_tensor * dst, + ggml_tensor * mul_tensor, + ggml_tensor * add_tensor); + +void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/opt-step-adamw.cu b/llama.cpp/ggml/src/ggml-cuda/opt-step-adamw.cu new file mode 100644 index 0000000..35154f2 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/opt-step-adamw.cu @@ -0,0 +1,78 @@ +#include "ggml-impl.h" +#include "opt-step-adamw.cuh" + +#include <cstdint> + +static __global__ void opt_step_adamw_f32( + float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v, + const float * __restrict__ pars, const int64_t k) { + + const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x; + + if (i >= k) { + return; + } + + const float alpha = pars[0]; + const float beta1 = pars[1]; + const float beta2 = pars[2]; + const float eps = pars[3]; + const float wd = pars[4]; + const float beta1h = pars[5]; + const float beta2h = pars[6]; + + const float gi = g[i]; + const float gmi = g_m[i]*beta1 + gi*(1.0f - beta1); + const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2); + + g_m[i] = gmi; + g_v[i] = gvi; + + const float mh = gmi*beta1h; + const float vh = sqrtf(gvi*beta2h) + eps; + + x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh; +} + +static void opt_step_adamw_f32_cuda( + float * x, const float * g, float * g_m, float * g_v, const float * pars, const int64_t k, cudaStream_t stream) { + + const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1); + const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1); + opt_step_adamw_f32<<<block_nums, block_dims, 0, stream>>>(x, g, g_m, g_v, pars, k); +} + +void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src0_grad = dst->src[1]; + const ggml_tensor * src0_grad_m = dst->src[2]; + const ggml_tensor * src0_grad_v = dst->src[3]; + const ggml_tensor * adamw_params = dst->src[4]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32); + GGML_ASSERT(adamw_params->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src0_grad)); + GGML_ASSERT(ggml_is_contiguous(src0_grad_m)); + GGML_ASSERT(ggml_is_contiguous(src0_grad_v)); + GGML_ASSERT(ggml_is_contiguous(adamw_params)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v)); + GGML_ASSERT(ggml_nelements(adamw_params) == 7); + + float * src0_d = (float *) src0->data; + const float * src0_grad_d = (const float *) src0_grad->data; + float * src0_grad_m_d = (float *) src0_grad_m->data; + float * src0_grad_v_d = (float *) src0_grad_v->data; + const float * adamw_params_d = (const float *) adamw_params->data; + + cudaStream_t stream = ctx.stream(); + + const int64_t ne = ggml_nelements(src0); + + opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, adamw_params_d, ne, stream); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/opt-step-adamw.cuh b/llama.cpp/ggml/src/ggml-cuda/opt-step-adamw.cuh new file mode 100644 index 0000000..58d6f6e --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/opt-step-adamw.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_OPT_STEP_ADAMW_BLOCK_SIZE 256 + +void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu b/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu new file mode 100644 index 0000000..460b16d --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu @@ -0,0 +1,49 @@ +#include "ggml-impl.h" +#include "opt-step-sgd.cuh" + +#include <cstdint> + +static __global__ void opt_step_sgd_f32( + float * __restrict__ x, const float * __restrict__ g, + const float * __restrict__ pars, const int64_t k) { + + const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x; + + if (i >= k) { + return; + } + x[i] = x[i] * (1.0f - pars[0] * pars[1]) - pars[0] * g[i]; +} + +static void opt_step_sgd_f32_cuda( + float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) { + + const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1); + const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1); + opt_step_sgd_f32<<<block_nums, block_dims, 0, stream>>>(x, g, pars, k); +} + +void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src0_grad = dst->src[1]; + const ggml_tensor * params = dst->src[2]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad->type == GGML_TYPE_F32); + GGML_ASSERT(params->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src0_grad)); + GGML_ASSERT(ggml_is_contiguous(params)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad)); + GGML_ASSERT(ggml_nelements(params) == 2); + + float * src0_d = (float *) src0->data; + const float * src0_grad_d = (const float *) src0_grad->data; + const float * params_d = (const float *) params->data; + + cudaStream_t stream = ctx.stream(); + + const int64_t ne = ggml_nelements(src0); + + opt_step_sgd_f32_cuda(src0_d, src0_grad_d, params_d, ne, stream); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh b/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh new file mode 100644 index 0000000..f97ab7d --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256 + +void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/out-prod.cu b/llama.cpp/ggml/src/ggml-cuda/out-prod.cu new file mode 100644 index 0000000..c9b2b69 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/out-prod.cu @@ -0,0 +1,68 @@ +#include "out-prod.cuh" + +#include <cstdint> + +void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ne01 == ne11); + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + + GGML_ASSERT(ne2 % src0->ne[2] == 0); + GGML_ASSERT(ne3 % src0->ne[3] == 0); + + GGML_ASSERT(ne2 == src1->ne[2]); + GGML_ASSERT(ne3 == src1->ne[3]); + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + cublasHandle_t handle = ctx.cublas_handle(); + + const float alpha = 1.0f; + const float beta = 0.0f; + + CUBLAS_CHECK(cublasSetStream(handle, stream)); + + const int64_t lda = nb01 / sizeof(float); + const int64_t ldc = nb1 / sizeof(float); + + const bool src1_T = ggml_is_transposed(src1); + const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T; + const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); + GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float)); + + // data strides in dimensions 2/3 + const size_t s02 = nb02 / sizeof(float); + const size_t s03 = nb03 / sizeof(float); + const size_t s12 = nb12 / sizeof(float); + const size_t s13 = nb13 / sizeof(float); + const size_t s2 = nb2 / sizeof(float); + const size_t s3 = nb3 / sizeof(float); + + // dps == dst per src0, used for group query attention + const int64_t dps2 = ne2 / ne02; + const int64_t dps3 = ne3 / ne03; + + // TODO batched matrix multiplication + for (int64_t i3 = 0; i3 < ne3; ++i3) { + for (int64_t i2 = 0; i2 < ne2; ++i2) { + CUBLAS_CHECK( + cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, + ne0, ne1, ne01, + &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda, + src1_d + i3 *s13 + i2 *s12, ldb, + &beta, dst_d + i3 *s3 + i2 *s2, ldc)); + } + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/out-prod.cuh b/llama.cpp/ggml/src/ggml-cuda/out-prod.cuh new file mode 100644 index 0000000..a0046f5 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/out-prod.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/pad.cu b/llama.cpp/ggml/src/ggml-cuda/pad.cu new file mode 100644 index 0000000..31cd00f --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/pad.cu @@ -0,0 +1,106 @@ +#include "pad.cuh" + +#include <stdint.h> + +__device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) { + // + size ensures negatives are handled properly + return (coord + size) % size; +} + +static __global__ void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst, + const int lp0, const int rp0, const int lp1, const int rp1, + const int lp2, const int rp2, const int lp3, const int rp3, + const int ne0, const int ne1, const int ne2, const int ne3, + const bool circular) { + // blockIdx.z: i3*ne2+i2 + // blockIdx.y: i1 + // blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE + // gridDim.y: ne1 + int i0 = threadIdx.x + blockIdx.x * blockDim.x; + int i1 = blockIdx.y; + int i2 = blockIdx.z % ne2; + int i3 = blockIdx.z / ne2; + + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int64_t dst_idx = i3 * (ne0 * ne1 * ne2) + i2 * (ne0 * ne1) + i1 * ne0 + i0; + + if (!circular) { + if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && (i2 >= lp2 && i2 < ne2 - rp2) && + (i3 >= lp3 && i3 < ne3 - rp3)) { + const int64_t i00 = i0 - lp0; + const int64_t i01 = i1 - lp1; + const int64_t i02 = i2 - lp2; + const int64_t i03 = i3 - lp3; + + const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; + + dst[dst_idx] = src[src_idx]; + } else { + dst[dst_idx] = 0.0f; + } + } + // circular means on a torus, so x and y wrap around + else { + const int64_t ne00 = ne0 - lp0 - rp0; + const int64_t ne01 = ne1 - lp1 - rp1; + const int64_t ne02 = ne2 - lp2 - rp2; + const int64_t ne03 = ne3 - lp3 - rp3; + + const int64_t i00 = wrap_around(i0 - lp0, ne00); + const int64_t i01 = wrap_around(i1 - lp1, ne01); + const int64_t i02 = wrap_around(i2 - lp2, ne02); + const int64_t i03 = wrap_around(i3 - lp3, ne03); + + const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; + + dst[dst_idx] = src[src_idx]; + } +} + + +static void pad_f32_cuda(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst, + const int lp0, const int rp0, const int lp1, const int rp1, + const int lp2, const int rp2, const int lp3, const int rp3, + const int ne0, const int ne1, const int ne2, const int ne3, + const bool circular, cudaStream_t stream) { + int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; + dim3 gridDim(num_blocks, ne1, ne2 * ne3); + pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, s00, s01, s02, s03, dst, + lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, + ne0, ne1, ne2, ne3, circular); +} + +void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_TENSOR_UNARY_OP_LOCALS; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int32_t lp0 = ((const int32_t *) (dst->op_params))[0]; + const int32_t rp0 = ((const int32_t *) (dst->op_params))[1]; + const int32_t lp1 = ((const int32_t *) (dst->op_params))[2]; + const int32_t rp1 = ((const int32_t *) (dst->op_params))[3]; + const int32_t lp2 = ((const int32_t *) (dst->op_params))[4]; + const int32_t rp2 = ((const int32_t *) (dst->op_params))[5]; + const int32_t lp3 = ((const int32_t *) (dst->op_params))[6]; + const int32_t rp3 = ((const int32_t *) (dst->op_params))[7]; + const int32_t circular = ((const int32_t *) (dst->op_params))[8]; + + const size_t s00 = nb00 / ggml_type_size(src0->type); + const size_t s01 = nb01 / ggml_type_size(src0->type); + const size_t s02 = nb02 / ggml_type_size(src0->type); + const size_t s03 = nb03 / ggml_type_size(src0->type); + + pad_f32_cuda(src0_d, s00, s01, s02, s03, dst_d, + lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (bool) circular, stream); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/pad.cuh b/llama.cpp/ggml/src/ggml-cuda/pad.cuh new file mode 100644 index 0000000..8fd386b --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/pad.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_PAD_BLOCK_SIZE 256 + +void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu b/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu new file mode 100644 index 0000000..32993eb --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu @@ -0,0 +1,91 @@ +#include "pad_reflect_1d.cuh" + +static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void + pad_reflect_1d_kernel_f32( + const void * __restrict__ src0, + void * __restrict__ dst, + const int64_t ne0, + const int64_t ne00, + const uint3 ne01, + const int64_t ne02, + const int64_t ne03, + const int64_t nb00, + const int64_t nb01, + const int64_t nb02, + const int64_t nb03, + const int64_t nb0, + const int64_t nb1, + const int64_t nb2, + const int64_t nb3, + const int p0, + const int p1) { + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + + const uint2 div_mod_packed = fast_div_modulo(blockIdx.x, ne01); + const int64_t tile1 = div_mod_packed.y; // i1 + const int64_t tile0 = div_mod_packed.x; // nth i0 tile + const int64_t i1 = tile1; + const int64_t i0 = threadIdx.x + tile0 * blockDim.x; + + // ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh) + if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) { + return; + } + + const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01; + char * dst_ptr = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1; + + const int64_t rel_i0 = i0 - p0; // relative i0 in src0 + int64_t src_idx; + + if (rel_i0 < 0) { + // Left padding - reflect + src_idx = -rel_i0; + } else if (rel_i0 < ne00) { + // Middle - copy + src_idx = rel_i0; + } else { + // Right padding - reflect + src_idx = 2 * ne00 - 2 - rel_i0; + } + const float value = *(const float *) (src0_ptr + src_idx * nb00); + *(float *) (dst_ptr + i0 * nb0) = value; + + GGML_UNUSED(p1); +} + +void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int32_t * opts = (const int32_t *) dst->op_params; + const int p0 = opts[0]; + const int p1 = opts[1]; + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const uint3 ne01_packed = init_fastdiv_values(ne01); + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne0 = dst->ne[0]; + + // sanity: padded length matches + GGML_ASSERT(ne0 == ne00 + p0 + p1); + + constexpr int64_t bx = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x) + const int64_t tiles0 = (ne0 + bx - 1) / bx; // number of tiles along i0 + // grid.x covers i1 and all tiles of i0: [ne01 * tiles0] + // grid.y covers i2: [ne02] + // grid.z covers i3: [ne03] + const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03); + const dim3 block_dims((unsigned) bx, 1, 1); + + pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>( + src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh b/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh new file mode 100644 index 0000000..15f2ed1 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_PAD_REFLECT_1D_BLOCK_SIZE 256 + +void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/pool2d.cu b/llama.cpp/ggml/src/ggml-cuda/pool2d.cu new file mode 100644 index 0000000..c6d51e4 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/pool2d.cu @@ -0,0 +1,94 @@ +#include "pool2d.cuh" + +template <typename Ti, typename To> +static __global__ void pool2d_nchw_kernel( + const int ih, const int iw, const int oh, const int ow, + const int kh, const int kw, const int sh, const int sw, + const int ph, const int pw, const int parallel_elements, + const Ti* src, To* dst, const enum ggml_op_pool op) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= parallel_elements) { + return; + } + + const int I_HW = ih * iw; + const int O_HW = oh * ow; + const int nc = idx / O_HW; + const int cur_oh = idx % O_HW / ow; + const int cur_ow = idx % O_HW % ow; + const Ti* i_ptr = src + nc * I_HW; + To* o_ptr = dst + nc * O_HW; + const int start_h = cur_oh * sh - ph; + const int bh = max(0, start_h); + const int eh = min(ih, start_h + kh); + const int start_w = cur_ow * sw - pw; + const int bw = max(0, start_w); + const int ew = min(iw, start_w + kw); + const To scale = 1. / (kh * kw); + To res = 0; + + switch (op) { + case GGML_OP_POOL_AVG: res = 0; break; + case GGML_OP_POOL_MAX: res = -FLT_MAX; break; + default: assert(false); + } + + for (int i = bh; i < eh; i += 1) { + for (int j = bw; j < ew; j += 1) { +#if __CUDA_ARCH__ >= 350 + Ti cur = __ldg(i_ptr + i * iw + j); +#else + Ti cur = i_ptr[i * iw + j]; +#endif + switch (op) { + case GGML_OP_POOL_AVG: res += cur * scale; break; + case GGML_OP_POOL_MAX: res = max(res, (To)cur); break; + default: assert(false); + } + } + } + o_ptr[cur_oh * ow + cur_ow] = res; +} + +static void pool2d_nchw_kernel_f32_f32_cuda( + const int ih, const int iw, const int oh, const int ow, + const int kh, const int kw, const int sh, const int sw, + const int ph, const int pw, const int parallel_elements, + const float * src, float * dst, const enum ggml_op_pool op, + cudaStream_t stream) { + + const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE; + dim3 block_nums(num_blocks); + pool2d_nchw_kernel<<<block_nums, CUDA_POOL2D_BLOCK_SIZE, 0, stream>>>(ih, iw, oh, ow, kh, kw, sh, sw, ph, pw, parallel_elements, src, dst, op); +} + +void ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int32_t * opts = (const int32_t *)dst->op_params; + enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]); + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; + + const int64_t IH = src0->ne[1]; + const int64_t IW = src0->ne[0]; + + const int64_t N = dst->ne[3]; + const int64_t OC = dst->ne[2]; + const int64_t OH = dst->ne[1]; + const int64_t OW = dst->ne[0]; + + const int parallel_elements = N * OC * OH * OW; + + pool2d_nchw_kernel_f32_f32_cuda(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_d, dst_d, op, stream); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/pool2d.cuh b/llama.cpp/ggml/src/ggml-cuda/pool2d.cuh new file mode 100644 index 0000000..7841292 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/pool2d.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_POOL2D_BLOCK_SIZE 256 + +void ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/quantize.cu b/llama.cpp/ggml/src/ggml-cuda/quantize.cu new file mode 100644 index 0000000..a8c68e4 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/quantize.cu @@ -0,0 +1,343 @@ +#include "quantize.cuh" +#include <cstdint> + +__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1) +static __global__ void quantize_q8_1( + const float * __restrict__ x, void * __restrict__ vy, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const uint32_t ne1, const uint3 ne2) { + const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; + + if (i0 >= ne0) { + return; + } + + const int64_t i3 = fastdiv(blockIdx.z, ne2); + const int64_t i2 = blockIdx.z - i3*ne2.z; + const int64_t i1 = blockIdx.y; + + const int64_t & i00 = i0; + const int64_t & i01 = i1; + const int64_t & i02 = i2; + const int64_t & i03 = i3; + + const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0; + + block_q8_1 * y = (block_q8_1 *) vy; + + const int64_t ib = i_cont / QK8_1; // block index + const int64_t iqs = i_cont % QK8_1; // quant index + + const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f; + float amax = fabsf(xi); + float sum = xi; + + amax = warp_reduce_max<QK8_1>(amax); + sum = warp_reduce_sum<QK8_1>(sum); + + const float d = amax / 127.0f; + const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); + + y[ib].qs[iqs] = q; + + if (iqs > 0) { + return; + } + + y[ib].ds = make_half2(d, sum); +} + +__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) { + if (!(amax > 0.0f)) { + return 0; + } + + // FP4 E2M1: max exponent (unbiased) is 2. + constexpr int FP4_E2M1_EMAX = 2; + + const float e = log2f(amax); + + // "even" -> round-to-nearest integer, ties-to-even + const int e_int = __float2int_rn(e); + + const int shared_exp = e_int - FP4_E2M1_EMAX; + + int biased = shared_exp + 127; + + biased = max(biased, 0); + biased = min(biased, 254); + + return static_cast<uint8_t>(biased); +} + +// quantize values in the format mxfp4 is stored which is interleaved nibbles +// i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31 +static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x, + const int32_t * __restrict__ ids, + void * __restrict__ vy, + const int64_t ne00, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t ne0, + const int ne1, + const int ne2) { + constexpr int vals_per_scale = 32; + constexpr int vals_per_warp = 2 * vals_per_scale; // Each warp processes 2 blocks of 32 = 64 values + + const int warp_id = threadIdx.y; + const int lane_id_32 = threadIdx.x; + + const int nwarps = blockDim.y; + + const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp; + + if (warp_start_offset >= ne0) { + return; + } + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.z % ne2; + const int64_t i3 = blockIdx.z / ne2; + + const int64_t i01 = ids ? ids[i1] : i1; + const int64_t i02 = i2; + const int64_t i03 = i3; + + block_fp4_mmq * y = (block_fp4_mmq *) vy; + + const int64_t block_fp4_mmq_size = 8 * QK_MXFP4; // 256 values + const int64_t ib0 = blockIdx.z * ((int64_t) ne1 * (ne0 / block_fp4_mmq_size)); + const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x; + const int64_t quad_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp; + + const int group_id = lane_id_32 / 4; + const int lane_in_group = lane_id_32 % 4; + const int base = group_id * 2; + char2 * yqs2 = (char2 *) y[ib].qs; + + const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01; + + uint8_t scales[2]; + +#pragma unroll + for (int b = 0; b < 2; ++b) { + const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32; + const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f; + + float amax = fabsf(xi); +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE)); + } + + const uint8_t e = compute_e8m0_scale(amax); + scales[b] = e; + const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e)); + +#if CUDART_VERSION >= 12080 + const float scaled_val = xi * inv_s; + + const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE); + const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE); + const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE); + const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE); + + if (lane_in_group == 0) { + __nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3)); + + yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = *(char2 *) &fp4_packed; + } +#else + // Fallback: manual FP4 conversion using LUT + const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s); + + const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE); + const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1, WARP_SIZE); + const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE); + const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE); + + if (lane_in_group == 0) { + char2 q; + q.x = (q_hi_0 << 4) | q_lo_0; + q.y = (q_hi_1 << 4) | q_lo_1; + yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = q; + } +#endif // CUDART_VERSION >= 12080 + } + + if (lane_id_32 == 0) { + // Store 2 scales packed into 1 uint32 + y[ib].d4[quad_idx_in_block] = (scales[1] << 8) | scales[0]; + } +} + +template <mmq_q8_1_ds_layout ds_layout> +static __global__ void quantize_mmq_q8_1( + const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int ne1, const int ne2) { + + constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32; + constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32; + + const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4; + + if (i0 >= ne0) { + return; + } + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.z % ne2; + const int64_t i3 = blockIdx.z / ne2; + + const int64_t i00 = i0; + const int64_t i01 = ids ? ids[i1] : i1; + const int64_t i02 = i2; + const int64_t i03 = i3; + + const float4 * x4 = (const float4 *) x; + + block_q8_1_mmq * y = (block_q8_1_mmq *) vy; + + const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel + const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel + const int64_t iqs = i0 % (4*QK8_1); // quant index in block + + // Load 4 floats per thread and calculate max. abs. value between them: + const float4 xi = i0 < ne00 ? x4[(i03*s03 + i02*s02 + i01*s01 + i00)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f); + float amax = fabsf(xi.x); + amax = fmaxf(amax, fabsf(xi.y)); + amax = fmaxf(amax, fabsf(xi.z)); + amax = fmaxf(amax, fabsf(xi.w)); + + // Exchange max. abs. value between vals_per_scale/4 threads. +#pragma unroll + for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE)); + } + + float sum; + if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) { + sum = xi.x + xi.y + xi.z + xi.w; + + // Calculate sums across vals_per_sum/4 threads. +#pragma unroll + for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) { + sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE); + } + } + + const float d_inv = 127.0f / amax; + char4 q; + q.x = roundf(xi.x*d_inv); + q.y = roundf(xi.y*d_inv); + q.z = roundf(xi.z*d_inv); + q.w = roundf(xi.w*d_inv); + + // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth: + char4 * yqs4 = (char4 *) y[ib].qs; + yqs4[iqs/4] = q; + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) { + if (iqs % 16 != 0 || iqs >= 96) { + return; + } + + y[ib].d2s6[2 + iqs/16] = sum; + + if (iqs % 64 != 0) { + return; + } + + const float d = 1.0f / d_inv; + + y[ib].d2s6[iqs/64] = d; + + return; + } + + if (iqs % 32 != 0) { + return; + } + + const float d = 1.0f / d_inv; + + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) { + y[ib].ds4[iqs/32] = make_half2(d, sum); + } else { + y[ib].d4[iqs/32] = d; + } +} + +void quantize_row_q8_1_cuda( + const float * x, const int32_t * ids, void * vy, const ggml_type type_src0, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { + GGML_ASSERT(!ids); + GGML_ASSERT(ne0 % QK8_1 == 0); + + const uint3 ne2_fastdiv = init_fastdiv_values(ne2); + + const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + const dim3 num_blocks(block_num_x, ne1, ne2*ne3); + const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); + quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv); + GGML_UNUSED(type_src0); +} + +void quantize_mmq_q8_1_cuda( + const float * x, const int32_t * ids, void * vy, const ggml_type type_src0, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(ne0 % (4*QK8_1) == 0); + + // ne1 tends to assume the highest values, therefore use it as the "x" dimension of the CUDA grid: + const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ); + const dim3 num_blocks(ne1, block_num_y, ne2*ne3); + const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1); + switch (mmq_get_q8_1_ds_layout(type_src0)) { + case MMQ_Q8_1_DS_LAYOUT_D4: + quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4> + <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + break; + case MMQ_Q8_1_DS_LAYOUT_DS4: + quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4> + <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + break; + case MMQ_Q8_1_DS_LAYOUT_D2S6: + quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6> + <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + +void quantize_mmq_mxfp4_cuda(const float * x, + const int32_t * ids, + void * vy, + [[maybe_unused]] const ggml_type type_src0, + const int64_t ne00, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t ne0, + const int64_t ne1, + const int64_t ne2, + const int64_t ne3, + cudaStream_t stream) { + GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); + + constexpr int nwarps = 8; + constexpr int vals_per_warp = 2 * QK_MXFP4; + constexpr int vals_per_block = nwarps * vals_per_warp; + + const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block; + const dim3 num_blocks(ne1, block_num_y, ne2 * ne3); + const dim3 block_size(WARP_SIZE, nwarps, 1); + + quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/quantize.cuh b/llama.cpp/ggml/src/ggml-cuda/quantize.cuh new file mode 100644 index 0000000..6a91df6 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/quantize.cuh @@ -0,0 +1,41 @@ +#pragma once + +#include "common.cuh" +#include "mmq.cuh" + +#include <cstdint> + +#define CUDA_QUANTIZE_BLOCK_SIZE 256 +#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128 + +static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access."); +static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access."); + +typedef void (*quantize_cuda_t)( + const float * x, const int32_t * ids, void * vy, + ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); + +void quantize_row_q8_1_cuda( + const float * x, const int32_t * ids, void * vy, + ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); + +void quantize_mmq_q8_1_cuda( + const float * x, const int32_t * ids, void * vy, + ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, + int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); + +void quantize_mmq_mxfp4_cuda(const float * x, + const int32_t * ids, + void * vy, + ggml_type type_src0, + int64_t ne00, + int64_t s01, + int64_t s02, + int64_t s03, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + cudaStream_t stream); diff --git a/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh b/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh new file mode 100644 index 0000000..de240fd --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh @@ -0,0 +1,39 @@ +#include "common.cuh" + +// Row reduction kernel template - compute sum (norm=false) or mean (norm=true) +template <bool norm> +static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) { + const int row = blockIdx.x; + const int col = threadIdx.x; + + float sum = 0.0f; + const int num_unroll = 8; + float temp[num_unroll]; + float sum_temp[num_unroll] = { 0.0f }; + for (int i = col; i < ncols;) { + for (int j = 0; j < num_unroll; ++j) { + if (i < ncols) { + temp[j] = x[row * ncols + i]; + } else { + temp[j] = 0; + } + i += blockDim.x; + } + for (int j = 0; j < num_unroll; ++j) { + sum_temp[j] += temp[j]; + } + } + for (int j = 0; j < num_unroll; ++j) { + sum += sum_temp[j]; + } + + // sum up partial sums + __shared__ float shared_vals[32]; + sum = block_reduce<block_reduce_method::SUM>(sum, shared_vals); + + if (col != 0) { + return; + } + + dst[row] = norm ? sum / ncols : sum; +} diff --git a/llama.cpp/ggml/src/ggml-cuda/roll.cu b/llama.cpp/ggml/src/ggml-cuda/roll.cu new file mode 100644 index 0000000..a339dfc --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/roll.cu @@ -0,0 +1,67 @@ +#include "ggml-cuda/common.cuh" +#include "roll.cuh" + +static __forceinline__ __device__ int64_t wrap_index(const int64_t idx, const int64_t ne) { + if (idx < 0) { + return idx + ne; + } + if (idx >= ne) { + return idx - ne; + } + return idx; +} + +static __global__ void roll_f32_cuda(const float * __restrict__ src, + float * __restrict__ dst, + const int64_t ne00, + const int64_t ne01, + const int64_t ne02, + const int64_t ne03, + const int s0, + const int s1, + const int s2, + const int s3) { + const int64_t idx = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; + const int64_t n_elements = ne00 * ne01 * ne02 * ne03; + + if (idx >= n_elements) { + return; + } + + const int64_t i0 = idx % ne00; + const int64_t i1 = (idx / ne00) % ne01; + const int64_t i2 = (idx / (ne00 * ne01)) % ne02; + const int64_t i3 = (idx / (ne00 * ne01 * ne02)) % ne03; + + const int64_t d0 = wrap_index(i0 - s0, ne00); + const int64_t d1 = wrap_index(i1 - s1, ne01); + const int64_t d2 = wrap_index(i2 - s2, ne02); + const int64_t d3 = wrap_index(i3 - s3, ne03); + + dst[i3 * (ne00 * ne01 * ne02) + i2 * (ne01 * ne00) + i1 * ne00 + i0] = + src[d3 * (ne00 * ne01 * ne02) + d2 * (ne01 * ne00) + d1 * ne00 + d0]; +} + +void ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + int s0 = dst->op_params[0]; + int s1 = dst->op_params[1]; + int s2 = dst->op_params[2]; + int s3 = dst->op_params[3]; + + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) dst->src[0]->data; + float * dst_d = (float *) dst->data; + + GGML_TENSOR_UNARY_OP_LOCALS; + + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst)); + + cudaStream_t stream = ctx.stream(); + + int64_t sz = (ne00 * ne01 * ne02 * ne03); + int64_t num_blocks = (sz + CUDA_ROLL_BLOCK_SIZE - 1) / CUDA_ROLL_BLOCK_SIZE; + + roll_f32_cuda<<<num_blocks, CUDA_ROLL_BLOCK_SIZE, 0, stream>>>( + src0_d, dst_d, ne00, ne01, ne02, ne03, s0, s1, s2, s3); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/roll.cuh b/llama.cpp/ggml/src/ggml-cuda/roll.cuh new file mode 100644 index 0000000..322d554 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/roll.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_ROLL_BLOCK_SIZE 256 + +void ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/rope.cu b/llama.cpp/ggml/src/ggml-cuda/rope.cu new file mode 100644 index 0000000..45a49a5 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/rope.cu @@ -0,0 +1,665 @@ +#include "convert.cuh" +#include "ggml-cuda/common.cuh" +#include "ggml.h" +#include "rope.cuh" + +struct rope_corr_dims { + float v[2]; +}; + + +struct mrope_sections { + int v[4]; +}; + +static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +template<bool forward> +static __device__ void rope_yarn( + const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor, + float mscale, float & cos_theta, float & sin_theta) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + cos_theta = cosf(theta) * mscale; + sin_theta = sinf(theta) * mscale; + if (!forward) { + sin_theta *= -1.0f; + } +} + +template <bool forward, bool has_ff, typename T, typename D> +static __global__ void rope_norm(const T * x, + D * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int32_t * pos, + const float freq_scale, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float theta_scale, + const float * freq_factors, + const int64_t * row_indices, + const int set_rows_stride) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); + + if (i0 >= ne00) { + return; + } + + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; + + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; + + int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03; + // Fusion optimization: ROPE + VIEW + SET_ROWS. + // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. + if (set_rows_stride != 0) { + idst = i1 * s1 + i0; + idst += row_indices[i2] * set_rows_stride; + } + + const auto & store_coaelsced = [&](float x0, float x1) { + if constexpr (std::is_same_v<float, D>) { + float2 v = make_float2(x0, x1); + ggml_cuda_memcpy_1<8>(dst + idst, &v); + } else if constexpr (std::is_same_v<half, D>) { + half2 v = make_half2(x0, x1); + ggml_cuda_memcpy_1<4>(dst + idst, &v); + } + }; + if (i0 >= n_dims) { + store_coaelsced(x[ix + 0], x[ix + 1]); + return; + } + + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); + + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); + + const float x0 = x[ix + 0]; + const float x1 = x[ix + 1]; + + store_coaelsced(x0 * cos_theta - x1 * sin_theta, x0 * sin_theta + x1 * cos_theta); +} + +template <bool forward, bool has_ff, typename T, typename D> +static __global__ void rope_neox(const T * x, + D * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int32_t * pos, + const float freq_scale, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float theta_scale, + const float * freq_factors, + const int64_t * row_indices, + const int set_rows_stride) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); + + if (i0 >= ne00) { + return; + } + + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; + + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; + + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + + // Fusion optimization: ROPE + VIEW + SET_ROWS. + // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. + if (set_rows_stride != 0) { + idst = i1 * s1 + i0 / 2; + idst += row_indices[i2] * set_rows_stride; + } + + if (i0 >= n_dims) { + dst[idst + i0 / 2 + 0] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 0]); + dst[idst + i0 / 2 + 1] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 1]); + + return; + } + + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); + + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); + + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims/2]; + + dst[idst + 0] = ggml_cuda_cast<D>(x0 * cos_theta - x1 * sin_theta); + dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta); +} + +template <bool forward, bool has_ff, typename T> +static __global__ void rope_multi(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int32_t * pos, + const float freq_scale, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float theta_scale, + const float * freq_factors, + const mrope_sections sections, + const bool is_imrope) { + const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y); + + if (i0 >= ne00) { + return; + } + + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; + + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; + + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + + if (i0 >= n_dims) { + dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; + dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; + + return; + } + + const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (is_imrope) { + if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h + theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w + theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t + theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f); + } else { + theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f); + } + } else { + if (sector < sections.v[0]) { + theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f); + } + } + + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); + + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims/2]; + + dst[idst + 0] = x0*cos_theta - x1*sin_theta; + dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta; +} + +template <bool forward, bool has_ff, typename T> +static __global__ void rope_vision(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int32_t * pos, + const float freq_scale, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float theta_scale, + const float * freq_factors, + const mrope_sections sections) { + const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); + + if (i0 >= ne00) { + return; + } + + const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; + + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; + + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + + const int sect_dims = sections.v[0] + sections.v[1]; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < sections.v[0]) { + const int p = sector; + theta_base = pos[i2] * powf(theta_scale, p); + } else if (sector >= sections.v[0] && sector < sec_w) { + const int p = sector - sections.v[0]; + theta_base = pos[i2 + ne02] * powf(theta_scale, p); + } + + const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; + + float cos_theta; + float sin_theta; + + rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta); + + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims]; + + dst[idst + 0] = x0*cos_theta - x1*sin_theta; + dst[idst + n_dims] = x0*sin_theta + x1*cos_theta; +} + +template <bool forward, typename T, typename D> +static void rope_norm_cuda(const T * x, + D * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int nr, + const int32_t * pos, + const float freq_scale, + const float freq_base, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float * freq_factors, + const int64_t * row_indices, + const int set_rows_stride, + cudaStream_t stream) { + GGML_ASSERT(ne00 % 2 == 0); + const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nr, n_blocks_x, 1); + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + + if (freq_factors == nullptr) { + rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); + } else { + rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); + } +} + +template <bool forward, typename T, typename D> +static void rope_neox_cuda(const T * x, + D * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int nr, + const int32_t * pos, + const float freq_scale, + const float freq_base, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float * freq_factors, + const int64_t * row_indices, + const int set_rows_stride, + cudaStream_t stream) { + GGML_ASSERT(ne00 % 2 == 0); + const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nr, n_blocks_x, 1); + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + + if (freq_factors == nullptr) { + rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); + } else { + rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); + } +} + +template <bool forward, typename T> +static void rope_multi_cuda(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int nr, + const int32_t * pos, + const float freq_scale, + const float freq_base, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float * freq_factors, + const mrope_sections sections, + const bool is_imrope, + cudaStream_t stream) { + GGML_ASSERT(ne00 % 2 == 0); + const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nr, n_blocks_x, 1); + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + + if (freq_factors == nullptr) { + rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); + } else { + rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); + } +} + +template <bool forward, typename T> +static void rope_vision_cuda(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int nr, + const int32_t * pos, + const float freq_scale, + const float freq_base, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float * freq_factors, + const mrope_sections sections, + cudaStream_t stream) { + GGML_ASSERT(ne00 % 2 == 0); + const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); + const dim3 block_nums(nr, n_blocks_x, 1); + // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq) + // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + if (freq_factors == nullptr) { + rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); + } else { + rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, sections); + } +} + +template <bool forward> +void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, + ggml_tensor * dst, + const ggml_tensor * set_rows = nullptr) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + + void * dst_d = dst->data; + const int64_t * row_indices = nullptr; + ggml_type dst_type = dst->type; + int set_rows_stride = 0; + + if (set_rows != nullptr) { + GGML_ASSERT(forward); + dst_d = set_rows->data; + row_indices = (const int64_t *) set_rows->src[1]->data; + dst_type = set_rows->type; + set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type); + } + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + // When not fused, src0 and dst types must match + // When fused (ROPE+VIEW+SET_ROWS), src0 may be F32 and dst may be F16 + GGML_ASSERT(src0->type == dst->type || (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16)); + + const int64_t ne00 = src0->ne[0]; // head dims + const int64_t ne01 = src0->ne[1]; // num heads + const int64_t ne02 = src0->ne[2]; // num heads + const int64_t nr = ggml_nrows(src0); + + const size_t s01 = src0->nb[1] / ggml_type_size(src0->type); + const size_t s02 = src0->nb[2] / ggml_type_size(src0->type); + const size_t s03 = src0->nb[3] / ggml_type_size(src0->type); + + const size_t s1 = dst->nb[1] / ggml_type_size(dst->type); + const size_t s2 = dst->nb[2] / ggml_type_size(dst->type); + const size_t s3 = dst->nb[3] / ggml_type_size(dst->type); + + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + //const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + mrope_sections sections; + + // RoPE alteration for extended context + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions.v, (int32_t *) dst->op_params + 11, sizeof(int)*4); + + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_mrope) { + GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0); + } + + if (is_vision) { + GGML_ASSERT(n_dims == ne00/2); + } + + const int32_t * pos = (const int32_t *) src1_d; + + const float * freq_factors = nullptr; + if (src2 != nullptr) { + freq_factors = (const float *) src2->data; + } + + rope_corr_dims corr_dims; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v); + + // compute + if (is_neox) { + if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { + rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); + } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { + rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); + } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { + rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); + } else { + GGML_ABORT("fatal error"); + } + } else if (is_mrope && !is_vision) { + if (src0->type == GGML_TYPE_F32) { + rope_multi_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, is_imrope, stream); + } else if (src0->type == GGML_TYPE_F16) { + rope_multi_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, is_imrope, stream); + } else { + GGML_ABORT("fatal error"); + } + } else if (is_vision) { + if (src0->type == GGML_TYPE_F32) { + rope_vision_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, stream); + } else if (src0->type == GGML_TYPE_F16) { + rope_vision_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, stream); + } else { + GGML_ABORT("fatal error"); + } + } else { + if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { + rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); + } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { + rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); + } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { + rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); + } else { + GGML_ABORT("fatal error"); + } + } +} + +void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_rope_impl<true>(ctx, dst); +} + +void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_rope_impl<false>(ctx, dst); +} + +void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) { + ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/rope.cuh b/llama.cpp/ggml/src/ggml-cuda/rope.cuh new file mode 100644 index 0000000..72af086 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/rope.cuh @@ -0,0 +1,9 @@ +#include "common.cuh" + +#define CUDA_ROPE_BLOCK_SIZE 256 + +void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows); diff --git a/llama.cpp/ggml/src/ggml-cuda/scale.cu b/llama.cpp/ggml/src/ggml-cuda/scale.cu new file mode 100644 index 0000000..0ddeff6 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/scale.cu @@ -0,0 +1,34 @@ +#include "scale.cuh" + +#define MAX_GRIDDIM_X 0x7FFFFFFF + +static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) { + int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x; + int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x; + + for (int64_t i = tid; i < nelements; i += stride) { + dst[i] = scale * x[i] + bias; + } +} + +static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) { + const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; + scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements); +} + +void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + float scale; + float bias; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&bias, (float *) dst->op_params + 1, sizeof(float)); + + scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/scale.cuh b/llama.cpp/ggml/src/ggml-cuda/scale.cuh new file mode 100644 index 0000000..8ff75c8 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/scale.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_SCALE_BLOCK_SIZE 256 + +void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/set-rows.cu b/llama.cpp/ggml/src/ggml-cuda/set-rows.cu new file mode 100644 index 0000000..631de7e --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/set-rows.cu @@ -0,0 +1,330 @@ +#include "set-rows.cuh" +#include "cpy-utils.cuh" + +typedef void (*set_rows_kernel_t)(const char * src, char * dst); + +// Generic quantized set_rows kernel template +template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)> +static __global__ void k_set_rows_quant(const float * __restrict__ src0, + const idx_t * __restrict__ src1, + block_type * __restrict__ dst, + const int64_t ne_total, + const int64_t ne10, + const int64_t ne11, + const int64_t ne12, + const int64_t ne13, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t s10, + const int64_t s11, + const int64_t s12, + const int64_t s1, + const int64_t s2, + const int64_t s3, + const uint3 ne00, + const uint3 ne01, + const uint3 ne02, + const uint3 ne11_fd, + const uint3 ne12_fd) { + const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; + + if (i >= ne_total) { + return; + } + + const int64_t i_base = i * qk; + uint32_t tmp = (uint32_t) i_base; + uint2 div_mod; + + div_mod = fast_div_modulo(tmp, ne00); + const int64_t i00 = div_mod.y; + tmp = div_mod.x; + + div_mod = fast_div_modulo(tmp, ne01); + const int64_t i01 = div_mod.y; + tmp = div_mod.x; + + div_mod = fast_div_modulo(tmp, ne02); + const int64_t i02 = div_mod.y; + const int64_t i03 = div_mod.x; + + const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd); + const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd); + const int64_t i10 = i01; + + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + + const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; + block_type * dst_row_ptr = dst + (dst_row*s1 + i02*s2 + i03*s3) / sizeof(block_type); + + const float * src_block = src0_row + i00; + block_type * dst_block = dst_row_ptr + i00 / qk; + + quantize_func(src_block, dst_block); + + GGML_UNUSED(ne10); + GGML_UNUSED(ne11); + GGML_UNUSED(ne12); + GGML_UNUSED(ne13); +} + +// Template dispatch function for quantized set_rows +template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)> +static void set_rows_cuda_quant( + const float * src0_d, const idx_t * src1_d, block_type * dst_d, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const size_t nb01, const size_t nb02, const size_t nb03, + const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { + + GGML_ASSERT(ne00 % qk == 0); + const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk; + const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE; + const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE); + const dim3 grid_size(num_blocks); + + const int64_t s01 = nb01/sizeof(float); + const int64_t s02 = nb02/sizeof(float); + const int64_t s03 = nb03/sizeof(float); + const int64_t s10 = nb10/sizeof(idx_t); + const int64_t s11 = nb11/sizeof(idx_t); + const int64_t s12 = nb12/sizeof(idx_t); + const int64_t s1 = nb1; + const int64_t s2 = nb2; + const int64_t s3 = nb3; + + if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) { + const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00); + const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01); + const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02); + const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11); + const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12); + + k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>( + src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, + ne01_fd, ne02_fd, ne11_fd, ne12_fd); + } +} + +template <typename src_t, typename idx_t, typename dst_t> +static __global__ void k_set_rows(const src_t * __restrict__ src0, + const idx_t * __restrict__ src1, + dst_t * __restrict__ dst, + const int64_t ne_total, + const int64_t ne10, + const int64_t ne11, + const int64_t ne12, + const int64_t ne13, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t s10, + const int64_t s11, + const int64_t s12, + const int64_t s1, + const int64_t s2, + const int64_t s3, + const uint3 ne00, + const uint3 ne01, + const uint3 ne02, + const uint3 ne11_fd, + const uint3 ne12_fd) { + const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; + + if (i >= ne_total) { + return; + } + + uint32_t tmp = (uint32_t) i; + uint2 div_mod; + + div_mod = fast_div_modulo(tmp, ne00); + const int64_t i00 = div_mod.y; + tmp = div_mod.x; + + div_mod = fast_div_modulo(tmp, ne01); + const int64_t i01 = div_mod.y; + tmp = div_mod.x; + + div_mod = fast_div_modulo(tmp, ne02); + const int64_t i02 = div_mod.y; + const int64_t i03 = div_mod.x; + + const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd); + const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd); + const int64_t i10 = i01; + + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + + const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; + dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3; + + dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]); + + GGML_UNUSED(ne10); + GGML_UNUSED(ne11); + GGML_UNUSED(ne12); + GGML_UNUSED(ne13); +} + +template<typename src_t, typename idx_t, typename dst_t> +static void set_rows_cuda( + const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const size_t nb01, const size_t nb02, const size_t nb03, + const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { + + const int64_t ne_total = ne00 * ne01 * ne02 * ne03; + const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE; + const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE); + const dim3 grid_size(num_blocks); + + + const int64_t s01 = nb01/sizeof(src_t); + const int64_t s02 = nb02/sizeof(src_t); + const int64_t s03 = nb03/sizeof(src_t); + const int64_t s10 = nb10/sizeof(idx_t); + const int64_t s11 = nb11/sizeof(idx_t); + const int64_t s12 = nb12/sizeof(idx_t); + const int64_t s1 = nb1/sizeof(dst_t); + const int64_t s2 = nb2/sizeof(dst_t); + const int64_t s3 = nb3/sizeof(dst_t); + + if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) { + const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00); + const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01); + const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02); + const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11); + const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12); + + k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, + s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd, + ne11_fd, ne12_fd); + } +} + +template<typename src_t, typename idx_t> +static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const src_t * src0_d = (const src_t *)src0->data; + const idx_t * src1_d = (const idx_t *)src1->data; + + GGML_TENSOR_BINARY_OP_LOCALS + + cudaStream_t stream = ctx.stream(); + + + if (dst->type == GGML_TYPE_F32) { + set_rows_cuda( + src0_d, src1_d, (float*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_F16) { + set_rows_cuda( + src0_d, src1_d, (half*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_BF16) { + set_rows_cuda( + src0_d, src1_d, (nv_bfloat16*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q4_0) { + set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>( + src0_d, src1_d, (block_q4_0*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q4_1) { + set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>( + src0_d, src1_d, (block_q4_1*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q5_0) { + set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>( + src0_d, src1_d, (block_q5_0*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q5_1) { + set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>( + src0_d, src1_d, (block_q5_1*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q8_0) { + set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>( + src0_d, src1_d, (block_q8_0*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_IQ4_NL) { + set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>( + src0_d, src1_d, (block_iq4_nl*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else { + GGML_ABORT("unsupported type %s", ggml_type_name(dst->type)); + } +} + + +void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32); + + if (src1->type == GGML_TYPE_I64) { + set_rows_cuda<float, int64_t>(ctx, src0, src1, dst); + } else { + set_rows_cuda<float, int32_t>(ctx, src0, src1, dst); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh b/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh new file mode 100644 index 0000000..c140c08 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh @@ -0,0 +1,7 @@ +#pragma once + +#include "common.cuh" + +#define CUDA_SET_ROWS_BLOCK_SIZE 256 + +void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/set.cu b/llama.cpp/ggml/src/ggml-cuda/set.cu new file mode 100644 index 0000000..04bfe07 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/set.cu @@ -0,0 +1,39 @@ +#include "set.cuh" +#include "cpy.cuh" + +void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)); + GGML_ASSERT(src1->type == src0->type); + GGML_ASSERT(dst ->type == src0->type); + + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const size_t nb1 = ((int32_t *) dst->op_params)[0]; + const size_t nb2 = ((int32_t *) dst->op_params)[1]; + const size_t nb3 = ((int32_t *) dst->op_params)[2]; + const size_t offset = ((int32_t *) dst->op_params)[3]; + const bool inplace= (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + ggml_cuda_cpy(ctx, src0, dst); + } + + ggml_tensor dst_view = *dst; + dst_view.data = (void *)((char *)dst->data + offset); + dst_view.ne[0] = src1->ne[0]; + dst_view.ne[1] = src1->ne[1]; + dst_view.ne[2] = src1->ne[2]; + dst_view.ne[3] = src1->ne[3]; + + dst_view.nb[0] = ggml_element_size(dst); + dst_view.nb[1] = nb1; + dst_view.nb[2] = nb2; + dst_view.nb[3] = nb3; + + ggml_cuda_cpy(ctx, src1, &dst_view); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/set.cuh b/llama.cpp/ggml/src/ggml-cuda/set.cuh new file mode 100644 index 0000000..dd09529 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/set.cuh @@ -0,0 +1,7 @@ +#pragma once + +#include "common.cuh" + +#define CUDA_SET_BLOCK_SIZE 256 + +void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/softcap.cu b/llama.cpp/ggml/src/ggml-cuda/softcap.cu new file mode 100644 index 0000000..40dfe45 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/softcap.cu @@ -0,0 +1,34 @@ +#include "softcap.cuh" + +static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = tanhf(scale * x[i]) * softcap; +} + +static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE; + softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, scale, softcap, k); +} + +// fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE +void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src) { + const ggml_tensor * src0 = src->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + float scale; + float softcap; + memcpy(&scale, (float *) src->op_params + 0, sizeof(float)); + memcpy(&softcap, (float *) dst->op_params + 0, sizeof(float)); + + softcap_f32_cuda(src0_d, dst_d, scale, softcap, ggml_nelements(src0), stream); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/softcap.cuh b/llama.cpp/ggml/src/ggml-cuda/softcap.cuh new file mode 100644 index 0000000..6d34fb2 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/softcap.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_SOFTCAP_BLOCK_SIZE 256 + +void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src); diff --git a/llama.cpp/ggml/src/ggml-cuda/softmax.cu b/llama.cpp/ggml/src/ggml-cuda/softmax.cu new file mode 100644 index 0000000..dc06d06 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/softmax.cu @@ -0,0 +1,472 @@ +#include "common.cuh" +#include "ggml.h" +#include "softmax.cuh" + +#ifdef GGML_USE_HIP +#include <hip/hip_cooperative_groups.h> +#else +#include <cooperative_groups.h> +#include <cooperative_groups/reduce.h> +#endif // GGML_USE_HIP + +#include <cstdint> +#include <utility> + +template <typename T> +static __device__ __forceinline__ float t2f32(T val) { + return (float) val; +} + +template <> +__device__ float __forceinline__ t2f32<half>(half val) { + return __half2float(val); +} + +struct soft_max_params { + + int64_t nheads; + uint32_t n_head_log2; + int64_t ncols; + int64_t nrows_x; + int64_t nrows_y; + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + int64_t nb11; + int64_t nb12; + int64_t nb13; + + int64_t ne12; + int64_t ne13; + float scale; + float max_bias; + float m0; + float m1; +}; + +// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. +// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +template <bool use_shared, int ncols_template, int block_size_template, typename T> +static __global__ void soft_max_f32( + const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) { + const int ncols = ncols_template == 0 ? p.ncols : ncols_template; + + const int tid = threadIdx.x; + + const int64_t i03 = blockIdx.z; + const int64_t i02 = blockIdx.y; + const int64_t i01 = blockIdx.x; + + //TODO: noncontigous inputs/outputs + const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y; + + const int64_t i11 = i01; + const int64_t i12 = i02 % p.ne12; + const int64_t i13 = i03 % p.ne13; + + x += int64_t(rowx)*ncols; + mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr); + dst += int64_t(rowx)*ncols; + + const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; + + const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1); + + extern __shared__ float data_soft_max_f32[]; + float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication + // shared memory buffer to cache values between iterations: + float * vals = use_shared ? buf_iw + WARP_SIZE : dst; + + float max_val = sinks ? sinks[i02] : -INFINITY; + +#pragma unroll + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + + if (ncols_template == 0 && col >= ncols) { + break; + } + + const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f); + + vals[col] = val; + max_val = max(max_val, val); + } + + // find the max value in the block + max_val = block_reduce<block_reduce_method::MAX, block_size_template>(max_val, buf_iw); + + float tmp = 0.0f; // partial sum + +#pragma unroll + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + + if (ncols_template == 0 && col >= ncols) { + break; + } + + const float val = expf(vals[col] - max_val); + tmp += val; + vals[col] = val; + } + + // find the sum of exps in the block + tmp = block_reduce<block_reduce_method::SUM, block_size_template>(tmp, buf_iw); + + if (sinks) { + tmp += expf(sinks[i02] - max_val); + } + + const float inv_sum = 1.0f / tmp; + +#pragma unroll + for (int col0 = 0; col0 < ncols; col0 += block_size) { + const int col = col0 + tid; + + if (ncols_template == 0 && col >= ncols) { + return; + } + + dst[col] = vals[col] * inv_sum; + } +} + +// TODO: Template to allow keeping ncols in registers if they fit +static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x, + float * __restrict__ dst, + float * __restrict__ tmp_maxs, + float * __restrict__ tmp_sums, + const soft_max_params p) { + namespace cg = cooperative_groups; + + const cg::grid_group g = cg::this_grid(); + + const int tid = threadIdx.x; + const int col_start = blockIdx.x * blockDim.x + tid; + const int n_elem_per_thread = 4; + + float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY }; + float local_max = -INFINITY; + const int step_size = gridDim.x * blockDim.x; + __shared__ float shared_vals[32]; + + // Compute thread-local max + for (int col = col_start; col < p.ncols;) { +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + local_max = fmaxf(local_max, local_vals[i]); + } + col += step_size * n_elem_per_thread; + } + + // Compute CTA-level max + local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals); + + // Store CTA-level max to GMEM + if (tid == 0) { + tmp_maxs[blockIdx.x] = local_max; + } + g.sync(); + + // Compute compute global max from CTA-level maxs + assert(gridDim.x < blockDim.x); // currently we only support this case + if (tid < gridDim.x) { + local_max = tmp_maxs[tid]; + } else { + local_max = -INFINITY; + } + local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals); + + // Compute softmax dividends, accumulate divisor + float tmp_expf = 0.0f; + for (int col = col_start; col < p.ncols;) { +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + if (idx < p.ncols) { + const float tmp = expf(local_vals[i] - local_max); + tmp_expf += tmp; + dst[idx] = tmp; + } + } + col += step_size * n_elem_per_thread; + } + + // Reduce divisor within CTA + tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals); + + // Store CTA-level sum to GMEM + if (tid == 0) { + tmp_sums[blockIdx.x] = tmp_expf; + } + g.sync(); + + // Compute global sum from CTA-level sums + if (tid < gridDim.x) { + tmp_expf = tmp_sums[tid]; + } else { + tmp_expf = 0.0f; + } + tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals); + + // Divide dividend by global sum + store data + for (int col = col_start; col < p.ncols;) { +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + local_vals[i] = idx < p.ncols ? dst[idx] : -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + if (idx < p.ncols) { + dst[idx] = local_vals[i] / tmp_expf; + } + } + col += step_size * n_elem_per_thread; + } +} + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ + +static __global__ void soft_max_back_f32( + const float * grad, const float * dstf, float * dst, const int ncols, const float scale) { + const int tid = threadIdx.x; + const int rowx = blockIdx.x; + + grad += int64_t(rowx)*ncols; + dstf += int64_t(rowx)*ncols; + dst += int64_t(rowx)*ncols; + + float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dgf_dot += dstf[col]*grad[col]; + } + + dgf_dot = warp_reduce_sum(dgf_dot); + + for (int col = tid; col < ncols; col += WARP_SIZE) { + dst[col] = scale * (grad[col] - dgf_dot) * dstf[col]; + } +} + +template<int... Ns, typename T> +static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst, + const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared) +{ + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + + auto launch_kernel = [=](auto I) -> bool { + constexpr int ncols = decltype(I)::value; + constexpr int block = (ncols > 1024 ? 1024 : ncols); + + if (p.ncols == ncols) { + CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo); + soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>> + (x, mask, sinks, dst, p); + return true; + } + return false; + }; + + // unary fold over launch_kernel + if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) { + return; + } + + //default case + CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo); + soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p); +} + +__launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x, + float * __restrict__ dst, + float * __restrict__ tmp_maxs, + float * __restrict__ tmp_sums, + const soft_max_params p) +// We loop over all instead of parallelizing across gridDim.y as cooperative groups +// currently only support synchronizing the complete grid if not launched as a cluster group +// (which requires CC > 9.0) +// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#grid-synchronization +// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group +{ + for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) { + soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_maxs, + tmp_sums, p); + } +} + +template <typename T> +static void soft_max_f32_cuda(const float * x, + const T * mask, + const float * sinks, + float * dst, + const soft_max_params & params, + cudaStream_t stream, + [[maybe_unused]] ggml_backend_cuda_context & ctx) { + int nth = WARP_SIZE; + const int64_t ncols_x = params.ncols; + + while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2; + const dim3 block_dims(nth, 1, 1); + const dim3 block_nums(params.ne01, params.ne02, params.ne03); + const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float); + static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted."); + + + const int id = ggml_cuda_get_device(); + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + + + if (nbytes_shared <= smpbo) { + launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared); + } else { + // Parallelize across SMs for top-p/dist-sampling + // The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and + // Can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution. + if (ggml_cuda_info().devices[id].supports_cooperative_launch && + ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr && + params.scale == 1.0f && params.max_bias == 0.0f) { + ggml_cuda_pool_alloc<float> tmp_maxs_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float)); + ggml_cuda_pool_alloc<float> tmp_sums_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float)); + + void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs_alloc.ptr, + (void *) &tmp_sums_alloc.ptr, (void *) const_cast<soft_max_params *>(¶ms) }; + CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols, + dim3(ggml_cuda_info().devices[id].nsm, 1, 1), + dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream)); + } else { + const size_t nbytes_shared_low = WARP_SIZE * sizeof(float); + soft_max_f32<false, 0, 0> + <<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params); + } + } +} + +static void soft_max_back_f32_cuda( + const float * grad, const float * dstf, float * dst, + const int ncols, const int nrows, const float scale, cudaStream_t stream) { + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(nrows, 1, 1); + + soft_max_back_f32<<<block_nums, block_dims, 0, stream>>>(grad, dstf, dst, ncols, scale); +} + +void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + + const float * src0_d = (const float *) src0->data; + const void * src1_d = src1 ? (const void *) src1->data : nullptr; + const void * src2_d = src2 ? (const void *) src2->data : nullptr; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional + + const int64_t nrows_x = ggml_nrows(src0); + const int64_t nrows_y = src0->ne[1]; + + const int64_t ne00 = src0->ne[0]; + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + + const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16); + + const int64_t nb11 = src1 ? src1->nb[1] : 1; + const int64_t nb12 = src1 ? src1->nb[2] : 1; + const int64_t nb13 = src1 ? src1->nb[3] : 1; + + const int64_t ne12 = src1 ? src1->ne[2] : 1; + const int64_t ne13 = src1 ? src1->ne[3] : 1; + + const uint32_t n_head = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + + soft_max_params params = {}; + params.nheads = src0->ne[2]; + params.n_head_log2 = n_head_log2; + params.ncols = ne00; + params.nrows_x = nrows_x; + params.nrows_y = nrows_y; + params.ne00 = src0->ne[0]; + params.ne01 = src0->ne[1]; + params.ne02 = src0->ne[2]; + params.ne03 = src0->ne[3]; + params.nb11 = nb11; + params.nb12 = nb12; + params.nb13 = nb13; + params.ne12 = ne12; + params.ne13 = ne13; + params.scale = scale; + params.max_bias = max_bias; + params.m0 = m0; + params.m1 = m1; + + if (use_f16) { + soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx); + } else { + soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx); + } +} + +void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // grad + const ggml_tensor * src1 = dst->src[1]; // forward pass output + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + + GGML_ASSERT(max_bias == 0.0f); + + soft_max_back_f32_cuda(src0_d, src1_d, dst_d, ncols, nrows, scale, stream); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/softmax.cuh b/llama.cpp/ggml/src/ggml-cuda/softmax.cuh new file mode 100644 index 0000000..93dfee8 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/softmax.cuh @@ -0,0 +1,7 @@ +#include "common.cuh" + +#define CUDA_SOFT_MAX_BLOCK_SIZE 1024 + +void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/solve_tri.cu b/llama.cpp/ggml/src/ggml-cuda/solve_tri.cu new file mode 100644 index 0000000..177ffc2 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/solve_tri.cu @@ -0,0 +1,275 @@ +#include "common.cuh" +#include "ggml.h" +#include "solve_tri.cuh" + +#define MAX_N_FAST 64 +#define MAX_K_FAST 32 + +static __global__ void get_batch_pointers(const float * A, + float * X, + const float ** A_ptrs, + float ** X_ptrs, + int64_t ne02, + int64_t total_batches, + size_t s02, + size_t s03, + size_t s2, + size_t s3) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_batches) { + return; + } + + const int64_t i3 = idx / ne02; + const int64_t i2 = idx % ne02; + + A_ptrs[idx] = A + i3 * s03 + i2 * s02; + X_ptrs[idx] = X + i3 * s3 + i2 * s2; +} + +static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx, + const float * A, + const float * B, + float * X, + int n, + int k, + int64_t ne02, + int64_t ne03, + size_t s02, + size_t s03, + size_t s12, + size_t s13, + size_t s2, + size_t s3, + cudaStream_t stream) { + const float alpha = 1.0f; + const int64_t total_batches = ne02 * ne03; + if (total_batches == 0) { + return; + } + + // Bulk copy B -> X (contiguous tensors) + if (X != B) { + const int64_t total_elements_BX = n * k * total_batches; + CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream)); + } + + const int id = ggml_cuda_get_device(); + + ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches); + ggml_cuda_pool_alloc<float *> X_ptrs_alloc(ctx.pool(id), total_batches); + + const float ** A_ptrs_dev = A_ptrs_alloc.get(); + float ** X_ptrs_dev = X_ptrs_alloc.get(); + + get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02, + total_batches, s02, s03, s2, s3); + + CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); + + // Yes, this is necessary, without this we get RMSE errors + CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH)); + CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N, + CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches)); + + // revert to standard mode from common.cuh + CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH)); + + GGML_UNUSED_VARS(s12, s13); +} + +// ====================== +// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction +// ====================== +// When ncols_template == 0 the bounds for the loops in this function are not +// known and can't be unrolled. As we want to keep pragma unroll for all other +// cases we supress the clang transformation warning here. +#ifdef __clang__ +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +template <int n_template, int k_template> +static __global__ void solve_tri_f32_fast(const float * __restrict__ A, + const float * __restrict__ B, + float * __restrict__ X, + const uint3 ne02, + const size_t nb02, + const size_t nb03, + const size_t nb12, + const size_t nb13, + const size_t nb2, + const size_t nb3, + const int n_arg, + const int k_arg) { + const int n = n_template == 0 ? n_arg : n_template; + const int k = k_template == 0 ? k_arg : k_template; + + const int batch_idx = blockIdx.x; + const int lane = threadIdx.x; + const int col_idx = threadIdx.y; + + if (col_idx >= k) { + return; + } + + const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02); + const int64_t i02 = i02_i03.y; + const int64_t i03 = i02_i03.x; + + const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03); + const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13); + float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3); + + __shared__ float sA[MAX_N_FAST * MAX_N_FAST]; + + const int offset = threadIdx.x + threadIdx.y * blockDim.x; + +#pragma unroll + for (int i = 0; i < n * n; i += k * WARP_SIZE) { + const int i0 = i + offset; + if (i0 < n * n) { + sA[i0] = A_batch[i0]; + } + } + + __syncthreads(); + + float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f; + float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f; + + const int half = WARP_SIZE; + const int nrows_low = (n < half) ? n : half; + +#pragma unroll + for (int row = 0; row < nrows_low; ++row) { + float sum = 0.0f; + if (lane < row) { + sum += sA[row * n + lane] * x_low; + } + sum = warp_reduce_sum(sum); + + if (lane == row) { + x_low = (x_low - sum) / sA[row * n + row]; + } + } + +#pragma unroll + for (int row = half; row < n; ++row) { + float sum = sA[row * n + lane] * x_low; + const int j = half + lane; + if (j < row) { + sum += sA[row * n + j] * x_high; + } + sum = warp_reduce_sum(sum); + + if (lane == row - half) { + x_high = (x_high - sum) / sA[row * n + row]; + } + } + +#pragma unroll + for (int rr = 0; rr < 2; ++rr) { + const int row = rr * WARP_SIZE + lane; + if (row < n) { + const float val = (row < half) ? x_low : x_high; + X_batch[row * k + col_idx] = val; + } + } +} +#ifdef __clang__ +# pragma clang diagnostic pop +#endif // __clang__ + +static void solve_tri_f32_cuda(const float * A, + const float * B, + float * X, + int n, + int k, + int64_t ne02, + int64_t ne03, + size_t nb02, + size_t nb03, + size_t nb12, + size_t nb13, + size_t nb2, + size_t nb3, + cudaStream_t stream) { + const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02); + dim3 threads(WARP_SIZE, k); + dim3 grid(ne02 * ne03); + if (n == 64) { + switch (k) { + case 32: + solve_tri_f32_fast<64, 32> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 16: + solve_tri_f32_fast<64, 16> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 14: + solve_tri_f32_fast<64, 14> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 12: + solve_tri_f32_fast<64, 12> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 10: + solve_tri_f32_fast<64, 10> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 8: + solve_tri_f32_fast<64, 8> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 6: + solve_tri_f32_fast<64, 6> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 4: + solve_tri_f32_fast<64, 4> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 2: + solve_tri_f32_fast<64, 2> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + case 1: + solve_tri_f32_fast<64, 1> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0); + break; + default: + solve_tri_f32_fast<0, 0> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k); + } + } else { // run general case + solve_tri_f32_fast<0, 0> + <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k); + } +} + +void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // A (n×n, lower triangular) + const ggml_tensor * src1 = dst->src[1]; // B (n×k) + + ggml_is_contiguous(src0); + ggml_is_contiguous(src1); + + const int64_t n = src0->ne[0]; + const int64_t k = src1->ne[0]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + if (n <= MAX_N_FAST && k <= MAX_K_FAST) { + solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, + src0->ne[2], src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), + src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float), + dst->nb[3] / sizeof(float), ctx.stream()); + } else { + solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, + ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float), + src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float), + dst->nb[3] / sizeof(float), ctx.stream()); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/solve_tri.cuh b/llama.cpp/ggml/src/ggml-cuda/solve_tri.cuh new file mode 100644 index 0000000..6399923 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/solve_tri.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu b/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu new file mode 100644 index 0000000..6d5ea70 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu @@ -0,0 +1,150 @@ +#include "ssm-conv.cuh" + +template <size_t split_d_inner, size_t d_conv> +static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, + const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, + float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, + const int64_t n_t) { + GGML_UNUSED(src0_nb0); + const int tid = threadIdx.x; + const int bidx = blockIdx.x; + const int bidy = blockIdx.y; + + const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1); + const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1); + float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0); + + const int stride_x = src0_nb1 / sizeof(float); + const int stride_w = src1_nb1 / sizeof(float); + const int stride_y = dst_nb1 / sizeof(float); + + float x[d_conv] = { 0.0f }; + float w[d_conv] = { 0.0f }; + +#pragma unroll + for (size_t j = 0; j < d_conv; j++) { + w[j] = w_block[tid * stride_w + j]; + } + + for (int64_t i = 0; i < n_t; i++) { + float sumf = 0.0f; + + if (i == 0) { + for (size_t j = 0; j < d_conv; j++) { + x[j] = x_block[tid * stride_x + j]; + } + } else { + x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1]; + } + +#pragma unroll + for (size_t j = 0; j < d_conv; j++) { + sumf += x[(i + j) % d_conv] * w[j]; + } + y_block[i * stride_y + tid] = sumf; + } +} + +template <size_t split_d_inner, size_t d_conv, int64_t split_n_t> +static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1, + const int src0_nb0, const int src0_nb1, const int src0_nb2, + const int src1_nb1, float * __restrict__ dst, const int dst_nb0, + const int dst_nb1, const int dst_nb2, const int64_t n_t) { + const int tid = threadIdx.x; + const int bidx = blockIdx.x; + const int bidy = blockIdx.y; + const int bidz = blockIdx.z; + + const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 + + bidz * split_n_t * src0_nb0); + const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1); + float * y_block = + (float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0); + + const int stride_x = src0_nb1 / sizeof(float); + const int stride_w = src1_nb1 / sizeof(float); + const int stride_y = dst_nb1 / sizeof(float); + + float x[d_conv] = { 0.0f }; + float w[d_conv] = { 0.0f }; + +#pragma unroll + for (size_t j = 0; j < d_conv; j++) { + w[j] = w_block[tid * stride_w + j]; + } + +#pragma unroll + for (int64_t i = 0; i < split_n_t; i++) { + if (bidz * split_n_t + i < n_t) { + float sumf = 0.0f; + + if (i == 0) { + for (size_t j = 0; j < d_conv; j++) { + x[j] = x_block[tid * stride_x + j]; + } + } else { + x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1]; + } + +#pragma unroll + for (size_t j = 0; j < d_conv; j++) { + sumf += x[(i + j) % d_conv] * w[j]; + } + y_block[i * stride_y + tid] = sumf; + } + } +} + +static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, + const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, + const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t, + const int64_t n_s, cudaStream_t stream) { + const int threads = 128; + GGML_ASSERT(nr % threads == 0); + + auto launch_kernel = [&](auto NC) { + constexpr int kNC = decltype(NC)::value; + if (n_t <= 32) { + const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); + ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + dst, dst_nb0, dst_nb1, dst_nb2, n_t); + } else { + const int64_t split_n_t = 32; + dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); + ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>( + src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); + } + }; + + switch (nc) { + case 3: launch_kernel(std::integral_constant<int, 3>{}); break; + case 4: launch_kernel(std::integral_constant<int, 4>{}); break; + case 9: launch_kernel(std::integral_constant<int, 9>{}); break; + default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now."); + } +} + +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; // conv_x + const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight + + const int64_t nc = src1->ne[0]; // d_conv + const int64_t nr = src0->ne[1]; // d_inner + const int64_t n_t = dst->ne[1]; // tokens per sequence + const int64_t n_s = dst->ne[2]; // number of sequences in the batch + + GGML_ASSERT(dst->ne[0] == nr); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1], + dst->nb[2], nc, nr, n_t, n_s, stream); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cuh b/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cuh new file mode 100644 index 0000000..8e6c1f0 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu b/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu new file mode 100644 index 0000000..c1d4e2b --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu @@ -0,0 +1,342 @@ +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 +#define USE_CUB +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 + +#ifdef USE_CUB +#include <cub/cub.cuh> +using namespace cub; +#endif // USE_CUB + +#include "ssm-scan.cuh" + +// We would like to keep pragma unroll for cases where L_template is not 0, +// so we suppress the clang transformation warning. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ +template <size_t splitD, size_t N, size_t L_template> +__global__ void __launch_bounds__(splitD, 1) + ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2, + const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5, + const int32_t * __restrict__ src6, float * __restrict__ dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, + const int src2_nb1, const int src2_nb2, const int src3_nb1, + const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, + const int64_t s_off, const int64_t d_inner, const int64_t L_param) +{ + const size_t L = L_template == 0 ? L_param : L_template; + const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2); + const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float)); + const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float)); + const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1); + const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3)); + const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3)); + float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float)); + float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2); + + const int stride_x = src1_nb2 / sizeof(float); + const int stride_dt = src2_nb1 / sizeof(float); + const int stride_B = src4_nb2 / sizeof(float); + const int stride_C = src5_nb2 / sizeof(float); + const int stride_y = d_inner; + + float regA[N]; + float regs0[N]; + + __shared__ float smemB[N]; + __shared__ float smemC[N]; + +#ifdef USE_CUB + using BlockLoad = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>; + using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>; + + union CubTempStorage { + typename BlockLoad::TempStorage load_temp; + typename BlockStore::TempStorage store_temp; + }; + __shared__ CubTempStorage cub_temp_storage; + + BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA); + BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0); +#else + const int stride_s0 = src0_nb2 / sizeof(float); + const int stride_A = src3_nb1 / sizeof(float); +#pragma unroll + for (size_t n = 0; n < N; ++n) + { + regA[n] = A_block[threadIdx.x * stride_A + n]; + regs0[n] = s0_block[threadIdx.x * stride_s0 + n]; + } +#endif + +#pragma unroll + for (size_t i = 0; i < L; i++) + { + if (threadIdx.x < N) + { + smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x]; + smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x]; + } + __syncthreads(); + + float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x]; + if (dt_soft_plus <= 20.0f) + { + dt_soft_plus = log1pf(expf(dt_soft_plus)); + } + float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus; + + float sumf = 0.0f; +#pragma unroll + for (size_t n = 0; n < N; n++) + { + float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt; + sumf += state * smemC[n]; + regs0[n] = state; + } + y_block[i * stride_y + threadIdx.x] = sumf; + } + +#ifdef USE_CUB + BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0); +#else + const int stride_s = stride_s0; +#pragma unroll + for (size_t n = 0; n < N; ++n) + { + s_block[threadIdx.x * stride_s + n] = regs0[n]; + } +#endif +} +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ + +// assumes as many threads as d_state +template <int c_factor, int d_state> +__global__ void __launch_bounds__(d_state, 1) + ssm_scan_f32_group( + const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, + const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, + const int32_t * __restrict__ src6, float * __restrict__ dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, + const int src2_nb1, const int src2_nb2, const int src3_nb1, + const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, + const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) { + + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int warp_idx = blockIdx.x * c_factor + warp; + + const int head_idx = warp_idx / d_head; + const int head_off = (warp_idx % d_head) * sizeof(float); + const int seq_idx = blockIdx.y; + + const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); + + // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase + const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float))); + const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); + const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1); + const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); + const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); + float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx; + float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + + // strides across n_seq_tokens + const int stride_x = src1_nb2 / sizeof(float); + const int stride_dt = src2_nb1 / sizeof(float); + const int stride_B = src4_nb2 / sizeof(float); + const int stride_C = src5_nb2 / sizeof(float); + const int stride_y = n_head * d_head; + + float state[c_factor]; + float state_sum = 0.0f; + +#pragma unroll + for (int j = 0; j < c_factor; j++) { + state[j] = s0_warp[WARP_SIZE * j + lane]; + } + + for (int64_t i = 0; i < n_tok; i++) { + // NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here. + // Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead. + const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]); + + state_sum = 0.0f; + const float dA = expf(dt_soft_plus * A_warp[0]); + const float x_dt = x_warp[i * stride_x] * dt_soft_plus; +#pragma unroll + for (int j = 0; j < c_factor; j++) { + const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane]; + const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane]; + state[j] = (state[j] * dA) + (B_val * x_dt); + state_sum += state[j] * C_val; + } + + // parallel accumulation for output + state_sum = warp_reduce_sum(state_sum); + + if (lane == 0) { + y_warp[i * stride_y] = state_sum; + } + } + + // write back the state +#pragma unroll + for (int j = 0; j < c_factor; j++) { + s_warp[WARP_SIZE * j + lane] = state[j]; + } +} + +static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3, + const float * src4, const float * src5, const int32_t * src6, float * dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, + const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, + const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, + const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, + cudaStream_t stream) { + // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! + if (src3_nb1 == sizeof(float)) { + // Mamba-2 + if (d_state == 128) { + constexpr int threads = 128; + constexpr int num_warps = threads/WARP_SIZE; + + const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); + ssm_scan_f32_group<128/WARP_SIZE, 128><<<blocks, threads, 0, stream>>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, + src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); + } else if (d_state == 256) { // Falcon-H1 + constexpr int threads = 256; + constexpr int num_warps = threads/WARP_SIZE; + + const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); + ssm_scan_f32_group<256/WARP_SIZE, 256><<<blocks, threads, 0, stream>>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, + src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); + } else { + GGML_ABORT("doesn't support d_state!=(128 or 256)."); + } + } else { + // Mamba-1 + constexpr int threads = 128; + GGML_ASSERT(n_head % threads == 0); + GGML_ASSERT(head_dim == 1); + GGML_ASSERT(n_group == 1); + const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); + const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float); + if (d_state == 16) { + switch (n_tok) + { + case 1: + ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + case 2: + ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + case 3: + ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + case 4: + ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + case 5: + ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + case 6: + ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + case 7: + ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + case 8: + ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + default: + ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, + src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); + break; + } + } else { + GGML_ABORT("doesn't support d_state!=16."); + } + } +} + +void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const struct ggml_tensor * src0 = dst->src[0]; // s + const struct ggml_tensor * src1 = dst->src[1]; // x + const struct ggml_tensor * src2 = dst->src[2]; // dt + const struct ggml_tensor * src3 = dst->src[3]; // A + const struct ggml_tensor * src4 = dst->src[4]; // B + const struct ggml_tensor * src5 = dst->src[5]; // C + const struct ggml_tensor * src6 = dst->src[6]; // ids + + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // head_dim or 1 + const int64_t nh = src1->ne[1]; // n_head + const int64_t ng = src4->ne[1]; // n_group + const int64_t n_t = src1->ne[2]; // number of tokens per sequence + const int64_t n_s = src1->ne[3]; // number of sequences in the batch + + const int64_t s_off = ggml_nelements(src1) * sizeof(float); + + GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src2->nb[0] == sizeof(float)); + GGML_ASSERT(src3->nb[0] == sizeof(float)); + GGML_ASSERT(src4->nb[0] == sizeof(float)); + GGML_ASSERT(src5->nb[0] == sizeof(float)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + const float * src2_d = (const float *) src2->data; + const float * src3_d = (const float *) src3->data; + const float * src4_d = (const float *) src4->data; + const float * src5_d = (const float *) src5->data; + const int32_t * src6_d = (const int32_t *) src6->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src6->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d, + src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2], + src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3], + s_off, nc, nr, nh, ng, n_t, n_s, stream); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cuh b/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cuh new file mode 100644 index 0000000..ee078f5 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/sum.cu b/llama.cpp/ggml/src/ggml-cuda/sum.cu new file mode 100644 index 0000000..c56257b --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/sum.cu @@ -0,0 +1,41 @@ +#include "sum.cuh" +#include "sumrows.cuh" + +#ifdef GGML_CUDA_USE_CUB +#include <cub/cub.cuh> +using namespace cub; +#endif // GGML_CUDA_USE_CUB + +#include <cstdint> + +void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) { +#ifdef GGML_CUDA_USE_CUB + size_t tmp_size = 0; + DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream); + ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size); + DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, x, dst, ne, stream); +#else + // Use (inefficient) sum_rows implementation as a fallback. + // For AMD there is rocPRIM which could be used as a drop-in replacement via hipcub but this would require C++11 -> C++14. + sum_rows_f32_cuda(x, dst, ne, 1, stream); + GGML_UNUSED(pool); +#endif // GGML_CUDA_USE_CUB +} + +void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguously_allocated(src0)); + + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + + const int64_t ne = ggml_nelements(src0); + + ggml_cuda_pool & pool = ctx.pool(); + cudaStream_t stream = ctx.stream(); + + sum_f32_cuda(pool, src0_d, dst_d, ne, stream); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/sum.cuh b/llama.cpp/ggml/src/ggml-cuda/sum.cuh new file mode 100644 index 0000000..8cadc37 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/sum.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream); + +void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/sumrows.cu b/llama.cpp/ggml/src/ggml-cuda/sumrows.cu new file mode 100644 index 0000000..4025771 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/sumrows.cu @@ -0,0 +1,43 @@ +#include "reduce_rows.cuh" +#include "sumrows.cuh" + +void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + const int id = ggml_cuda_get_device(); + const int nsm = ggml_cuda_info().devices[id].nsm; + const dim3 block_nums(nrows, 1, 1); + if ((nrows / nsm) < 2) { + const dim3 block_dims(512, 1, 1); + reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols); + } else { + const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); + reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols); + } +} + +void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + const dim3 block_nums(nrows, 1, 1); + + const int id = ggml_cuda_get_device(); + const int nsm = ggml_cuda_info().devices[id].nsm; + if ((nrows / nsm) < 2) { + // Increase num threads to 512 for small nrows to better hide the latency + const dim3 block_dims(512, 1, 1); + reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols); + } else { + // Enough active SMs to hide latency, use smaller blocks to allow better scheduling + const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); + reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh b/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh new file mode 100644 index 0000000..3431c59 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh @@ -0,0 +1,4 @@ +#include "common.cuh" + +void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream); +void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu new file mode 100644 index 0000000..fb26abe --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu new file mode 100644 index 0000000..1f554d8 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu new file mode 100644 index 0000000..dc16829 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu new file mode 100644 index 0000000..9d3cfd8 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu new file mode 100644 index 0000000..2e1883a --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu new file mode 100644 index 0000000..517993c --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -0,0 +1,11 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 16, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu new file mode 100644 index 0000000..f011a20 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu new file mode 100644 index 0000000..264751d --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu new file mode 100644 index 0000000..97b19c6 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -0,0 +1,11 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu new file mode 100644 index 0000000..163b1d9 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu new file mode 100644 index 0000000..0543532 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu new file mode 100644 index 0000000..407b6cf --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu new file mode 100644 index 0000000..f5fd0e2 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu new file mode 100644 index 0000000..5e46685 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu new file mode 100644 index 0000000..989626d --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -0,0 +1,11 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu new file mode 100644 index 0000000..bad296b --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu new file mode 100644 index 0000000..0d7a9c7 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu new file mode 100644 index 0000000..9d5a997 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu new file mode 100644 index 0000000..a6e6f09 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu new file mode 100644 index 0000000..173de7a --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -0,0 +1,11 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu new file mode 100644 index 0000000..680a13c --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -0,0 +1,10 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8); +DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8); +DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8); +DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8); +DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8); +DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu new file mode 100644 index 0000000..a8b15ad --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(112, 112); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu new file mode 100644 index 0000000..1da1810 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(128, 128); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu new file mode 100644 index 0000000..bc65c72 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(256, 256); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu new file mode 100644 index 0000000..10b330f --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(40, 40); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu new file mode 100644 index 0000000..254b7d2 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(576, 512); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu new file mode 100644 index 0000000..5caffac --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(64, 64); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu new file mode 100644 index 0000000..8f9d531 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(72, 72); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu new file mode 100644 index 0000000..90abb3b --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(80, 80); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu new file mode 100644 index 0000000..7292c0a --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(96, 96); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu new file mode 100644 index 0000000..c357abd --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu new file mode 100644 index 0000000..4b14865 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu new file mode 100644 index 0000000..ef77157 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu new file mode 100644 index 0000000..9ae11cc --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu new file mode 100644 index 0000000..10ed48a --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu new file mode 100644 index 0000000..4fcc3f3 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu new file mode 100644 index 0000000..7ca5053 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu new file mode 100644 index 0000000..6ef1a48 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu new file mode 100644 index 0000000..4c0532c --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu new file mode 100644 index 0000000..ed3d7ba --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu new file mode 100644 index 0000000..687f254 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu new file mode 100644 index 0000000..41107c4 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu new file mode 100644 index 0000000..d523ce0 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu new file mode 100644 index 0000000..8b9ed35 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu new file mode 100644 index 0000000..0553e46 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu new file mode 100644 index 0000000..8390eaf --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu new file mode 100644 index 0000000..f61e19d --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu new file mode 100644 index 0000000..86a1882 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu new file mode 100644 index 0000000..1d7af47 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu new file mode 100644 index 0000000..837224d --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu new file mode 100644 index 0000000..0dd7dd6 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu new file mode 100644 index 0000000..41b859f --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu new file mode 100644 index 0000000..d2e5ffd --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu new file mode 100644 index 0000000..81ff740 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu new file mode 100644 index 0000000..a38dae1 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu new file mode 100644 index 0000000..2304571 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu new file mode 100644 index 0000000..84b83e5 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu new file mode 100644 index 0000000..39f80e2 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu new file mode 100644 index 0000000..cf4e661 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu new file mode 100644 index 0000000..6565418 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu new file mode 100644 index 0000000..a1bc3f5 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu new file mode 100644 index 0000000..4b76a9b --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu new file mode 100644 index 0000000..77d0412 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu new file mode 100644 index 0000000..6e170fe --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu new file mode 100644 index 0000000..b617cd7 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu new file mode 100644 index 0000000..a5b768b --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/llama.cpp/ggml/src/ggml-cuda/template-instances/generate_cu_files.py new file mode 100755 index 0000000..e382df1 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 + +from glob import glob +import os + +HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576] + +TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"] + +SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE({head_size_kq}, {head_size_v}); +""" + +SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, {type_k}, {type_v}); +DECL_FATTN_VEC_CASE(128, {type_k}, {type_v}); +DECL_FATTN_VEC_CASE(256, {type_k}, {type_v}); +""" + +SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +""" + +SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n" + +TYPES_MMQ = [ + "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", + "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", + "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", + "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4" +] + +SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE({type}); +""" + +SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE({type}); +""" + + +def get_short_name(long_quant_name): + return long_quant_name.replace("GGML_TYPE_", "").lower() + + +for filename in glob("*.cu"): + os.remove(filename) + +for head_size_kq in HEAD_SIZES_KQ: + head_size_v = head_size_kq if head_size_kq != 576 else 512 + with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f: + f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v)) + +for type_k in TYPES_KV: + for type_v in TYPES_KV: + with open(f"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: + f.write(SOURCE_FATTN_VEC.format(type_k=type_k, type_v=type_v)) + +for ncols in [8, 16, 32, 64]: + for ncols2 in [1, 2, 4, 8, 16, 32]: + if ncols2 > ncols: + continue + ncols1 = ncols // ncols2 + with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f: + f.write(SOURCE_FATTN_MMA_START) + + for head_size_kq in HEAD_SIZES_KQ: + if head_size_kq == 40: + continue + if head_size_kq == 72: + continue + if head_size_kq != 576 and ncols2 in (16, 32): + continue + if head_size_kq == 576 and ncols2 not in (4, 16, 32): + continue + head_size_v = head_size_kq if head_size_kq != 576 else 512 + f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) + +for type in TYPES_MMQ: + with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: + f.write(SOURCE_MMQ.format(type=type)) + +for type in range(1, 17): + with open(f"mmf-instance-ncols_{type}.cu", "w") as f: + f.write(SOURCE_MMF.format(type=type)) diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu new file mode 100644 index 0000000..f594d5d --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu new file mode 100644 index 0000000..9cc6772 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(10); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu new file mode 100644 index 0000000..317f487 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(11); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu new file mode 100644 index 0000000..dc00332 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(12); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu new file mode 100644 index 0000000..0782101 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(13); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu new file mode 100644 index 0000000..a23ad6a --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(14); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu new file mode 100644 index 0000000..0fe3f78 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(15); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu new file mode 100644 index 0000000..5440863 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(16); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu new file mode 100644 index 0000000..3b90179 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(2); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu new file mode 100644 index 0000000..56e940b --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(3); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu new file mode 100644 index 0000000..a7665d4 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(4); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu new file mode 100644 index 0000000..3a1dff2 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(5); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu new file mode 100644 index 0000000..400fb7c --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(6); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu new file mode 100644 index 0000000..954a1c7 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(7); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu new file mode 100644 index 0000000..f1bd09c --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(8); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu new file mode 100644 index 0000000..1255ac2 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(9); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu new file mode 100644 index 0000000..84ec850 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ1_S); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu new file mode 100644 index 0000000..583c4e5 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_S); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu new file mode 100644 index 0000000..edaf156 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu new file mode 100644 index 0000000..233d934 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu new file mode 100644 index 0000000..6092dc7 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ3_S); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu new file mode 100644 index 0000000..1d5bd20 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu new file mode 100644 index 0000000..eb02fab --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu new file mode 100644 index 0000000..1eb3b74 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu new file mode 100644 index 0000000..c14624c --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_MXFP4); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu new file mode 100644 index 0000000..6415369 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q2_K); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu new file mode 100644 index 0000000..ffb6213 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q3_K); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu new file mode 100644 index 0000000..0c0b0c8 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q4_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu new file mode 100644 index 0000000..ee67f69 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q4_1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu new file mode 100644 index 0000000..9eeb3cd --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q4_K); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu new file mode 100644 index 0000000..cc57fb9 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q5_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu new file mode 100644 index 0000000..721ac79 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q5_1); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu new file mode 100644 index 0000000..a2e90ff --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q5_K); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu new file mode 100644 index 0000000..470938f --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q6_K); diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu new file mode 100644 index 0000000..974477b --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q8_0); diff --git a/llama.cpp/ggml/src/ggml-cuda/top-k.cu b/llama.cpp/ggml/src/ggml-cuda/top-k.cu new file mode 100644 index 0000000..785a183 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/top-k.cu @@ -0,0 +1,95 @@ +#include "argsort.cuh" +#include "top-k.cuh" + +#ifdef GGML_CUDA_USE_CUB +# include <cub/cub.cuh> +# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2) +# define CUB_TOP_K_AVAILABLE +using namespace cub; +# endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2 +#endif // GGML_CUDA_USE_CUB + +#ifdef CUB_TOP_K_AVAILABLE + +static void top_k_cub(ggml_cuda_pool & pool, + const float * src, + int * dst, + const int ncols, + const int k, + cudaStream_t stream) { + auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed, + cuda::execution::output_ordering::unsorted); + auto stream_env = cuda::stream_ref{ stream }; + auto env = cuda::std::execution::env{ stream_env, requirements }; + + auto indexes_in = cuda::make_counting_iterator(0); + + size_t temp_storage_bytes = 0; + DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k, + env); + + ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes); + void * d_temp_storage = temp_storage_alloc.get(); + + DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, + ncols, k, env); +} + +#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE + +static int next_power_of_2(int x) { + int n = 1; + while (n < x) { + n *= 2; + } + return n; +} + +#endif // CUB_TOP_K_AVAILABLE + +void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + int * dst_d = (int *) dst->data; + cudaStream_t stream = ctx.stream(); + + // are these asserts truly necessary? + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + const int64_t k = dst->ne[0]; + ggml_cuda_pool & pool = ctx.pool(); +#ifdef CUB_TOP_K_AVAILABLE + // TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented + // https://github.com/NVIDIA/cccl/issues/6391 + // TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k + for (int i = 0; i < nrows; i++) { + top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream); + } +#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE + // Fall back to argsort + copy + const int ncols_pad = next_power_of_2(ncols); + const size_t shared_mem = ncols_pad * sizeof(int); + const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb; + + ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows); + int * tmp_dst = temp_dst_alloc.get(); + + if (shared_mem > max_shared_mem || ncols > 1024) { + argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream); + } else { + argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream); + } + CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows, + cudaMemcpyDeviceToDevice, stream)); +#else // GGML_CUDA_USE_CUB + ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows); + int * tmp_dst = temp_dst_alloc.get(); + argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream); + CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows, + cudaMemcpyDeviceToDevice, stream)); +#endif +} diff --git a/llama.cpp/ggml/src/ggml-cuda/top-k.cuh b/llama.cpp/ggml/src/ggml-cuda/top-k.cuh new file mode 100644 index 0000000..f4d8f61 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/top-k.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu b/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu new file mode 100644 index 0000000..08a8899 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu @@ -0,0 +1,403 @@ +#include "ggml-cuda/common.cuh" +#include "ggml.h" +#include "topk-moe.cuh" + +#include <cmath> +#include <initializer_list> + +// Kernel config struct - passed by value to CUDA kernel +struct topk_moe_config { + bool use_sigmoid; + bool with_norm; + bool delayed_softmax; +}; + +// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path. +template <int experts_per_thread, bool use_limit> +__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) { + float max_val = -INFINITY; + +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const int idx = lane + i * WARP_SIZE; + const bool active = !use_limit || (idx < limit); + if (active) { + max_val = max(max_val, vals[i]); + } + } + + max_val = warp_reduce_max(max_val); + + float sum = 0.f; + +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const int idx = lane + i * WARP_SIZE; + const bool active = !use_limit || (idx < limit); + if (active) { + const float val = expf(vals[i] - max_val); + vals[i] = val; + sum += val; + } else { + vals[i] = 0.f; + } + } + + sum = warp_reduce_sum(sum); + + const float inv_sum = 1.0f / sum; + +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const int idx = lane + i * WARP_SIZE; + const bool active = !use_limit || (idx < limit); + if (active) { + vals[i] *= inv_sum; + } + } +} + +template <int experts_per_thread, bool use_limit> +__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) { +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const int idx = lane + i * WARP_SIZE; + const bool active = !use_limit || (idx < limit); + vals[i] = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY; + } +} + +/* + This kernel does the following: + 1. optionally softmax over the logits per token [n_experts, n_tokens] + 2. argmax reduce over the top-k (n_experts_used) logits + 3. write weights + ids to global memory + 4. optionally normalize the weights or apply softmax over the selected logits + + It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models +*/ +template <int n_experts, bool has_bias> +__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, + float * weights, + int32_t * ids, + float * bias, + const int n_rows, + const int n_expert_used, + const float clamp_val, + const float scale_val, + const topk_moe_config config) { + const int row = blockIdx.x * blockDim.y + threadIdx.y; + if (row >= n_rows) { + return; + } + + logits += n_experts * row; + weights += n_expert_used * row; + ids += n_experts * row; + + constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1; + + float wt[experts_per_thread]; + + // Initialize all slots to -INFINITY +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + wt[i] = -INFINITY; + } + +#pragma unroll + for (int i = 0; i < n_experts; i += WARP_SIZE) { + const int expert = i + threadIdx.x; + wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY; + } + + if (!config.delayed_softmax) { + if (config.use_sigmoid) { + sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x); + } else { + softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x); + } + } + + // selection_wt is only needed when bias is present (selection uses wt + bias) + // when no bias, we use wt directly for both selection and weight values + float selection_wt[has_bias ? experts_per_thread : 1]; + + if constexpr (has_bias) { +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + selection_wt[i] = -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_experts; i += WARP_SIZE) { + const int expert = i + threadIdx.x; + selection_wt[i / WARP_SIZE] = + (n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY; + } + } + + //at this point, each thread holds either a portion of the softmax distribution + //or the raw logits. We do the argmax reduce over n_expert_used, each time marking + //the expert weight as -inf to exclude from the next iteration + + float wt_sum = 0.f; + + float output_weights[experts_per_thread]; + +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + output_weights[i] = 0.f; + } + + for (int k = 0; k < n_expert_used; k++) { + float max_val = wt[0]; + int max_expert = threadIdx.x; + + if constexpr (has_bias) { + float max_val_s = selection_wt[0]; + +#pragma unroll + for (int i = 1; i < experts_per_thread; i++) { + const int expert = threadIdx.x + i * WARP_SIZE; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) { + max_val = wt[i]; + max_val_s = selection_wt[i]; + max_expert = expert; + } + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { + const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); + const float val_s = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE); + const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); + if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) { + max_val = val; + max_val_s = val_s; + max_expert = expert; + } + } + + if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { + selection_wt[max_expert / WARP_SIZE] = -INFINITY; + } + } else { +#pragma unroll + for (int i = 1; i < experts_per_thread; i++) { + const int expert = threadIdx.x + i * WARP_SIZE; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { + max_val = wt[i]; + max_expert = expert; + } + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { + const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); + const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); + if (val > max_val || (val == max_val && expert < max_expert)) { + max_val = val; + max_expert = expert; + } + } + + if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { + wt[max_expert / WARP_SIZE] = -INFINITY; + } + } + + if ((k & (WARP_SIZE - 1)) == threadIdx.x) { + output_weights[k / WARP_SIZE] = max_val; + } + + if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { + ids[k] = max_expert; + if (config.with_norm) { + wt_sum += max_val; + } + } + } + + if (config.with_norm) { + wt_sum = warp_reduce_sum(wt_sum); + wt_sum = max(wt_sum, clamp_val); + const float inv_sum = 1.0f / wt_sum; + + for (int i = 0; i < experts_per_thread; i++) { + output_weights[i] *= inv_sum; + } + } + + if (config.delayed_softmax) { + softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x); + } + +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const int idx = i * WARP_SIZE + threadIdx.x; + if (idx < n_expert_used) { + weights[idx] = output_weights[i] * scale_val; + } + } +} + +template<bool has_bias> +static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, + const float * logits, + float * weights, + int32_t * ids, + float * bias, + const int n_rows, + const int n_expert, + const int n_expert_used, + const float clamp_val, + const float scale_val, + const topk_moe_config config) { + GGML_ASSERT(!(config.with_norm && config.delayed_softmax) && + "delayed softmax is not supported with weight normalization"); + const int rows_per_block = 4; + dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); + dim3 block_dims(WARP_SIZE, rows_per_block, 1); + cudaStream_t stream = ctx.stream(); + + switch (n_expert) { + case 1: + topk_moe_cuda<1, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 2: + topk_moe_cuda<2, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 4: + topk_moe_cuda<4, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 8: + topk_moe_cuda<8, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 16: + topk_moe_cuda<16, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 32: + topk_moe_cuda<32, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 64: + topk_moe_cuda<64, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 128: + topk_moe_cuda<128, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 256: + topk_moe_cuda<256, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 512: + topk_moe_cuda<512, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 576: + topk_moe_cuda<576, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + default: + GGML_ASSERT(false && "fatal error"); + break; + } +} + +void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, + const ggml_tensor * logits, + ggml_tensor * weights, + ggml_tensor * ids, + const ggml_tensor * clamp, + const ggml_tensor * scale, + const ggml_tensor * bias, + const ggml_cuda_topk_moe_args & args) { + GGML_ASSERT(logits->type == GGML_TYPE_F32); + GGML_ASSERT(weights->type == GGML_TYPE_F32); + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + const int n_experts = logits->ne[0]; + const int n_rows = logits->ne[1]; + + const float * logits_d = (const float *) logits->data; + float * weights_d = (float *) weights->data; + int32_t * ids_d = (int32_t *) ids->data; + float * bias_d = bias ? (float *) bias->data : nullptr; + + float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f; + + GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts); + + const int n_expert_used = weights->ne[1]; + + const bool with_norm = clamp != nullptr; + + float clamp_val = -INFINITY; + if (clamp) { + clamp_val = ggml_get_op_params_f32(clamp, 0); + } + + topk_moe_config config; + config.use_sigmoid = args.sigmoid; + config.with_norm = with_norm; + config.delayed_softmax = args.delayed_softmax; + + if (bias) { + launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val, + scale_val, config); + } else { + launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val, + scale_val, config); + } +} + +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op, + const ggml_tensor * weights, + const ggml_tensor * logits, + const ggml_tensor * ids) { + const int n_expert = ids->nb[1] / ids->nb[0]; + if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) { + return false; + } + + if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) { + return false; + } + + if (gating_op->op == GGML_OP_SOFT_MAX) { + const ggml_tensor * softmax = gating_op; + float scale = 1.0f; + float max_bias = 0.0f; + + memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float)); + + if (!ggml_is_contiguous(softmax->src[0])) { + return false; + } + + if (scale != 1.0f || max_bias != 0.0f) { + return false; + } + + // don't fuse when masks or sinks are present + if (softmax->src[1] || softmax->src[2]) { + return false; + } + } else if (gating_op->op == GGML_OP_UNARY) { + ggml_unary_op op = ggml_get_unary_op(gating_op); + + if (op != GGML_UNARY_OP_SIGMOID) { + return false; + } + } + + return true; +} diff --git a/llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh b/llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh new file mode 100644 index 0000000..243dc2f --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh @@ -0,0 +1,27 @@ +#include "common.cuh" +#include "ggml.h" + +#include <initializer_list> + +struct ggml_cuda_topk_moe_args { + bool sigmoid{}; + bool softmax{}; + bool delayed_softmax{}; + bool prob_bias{}; + bool norm{}; + bool scale{}; +}; + +void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, + const ggml_tensor * logits, + ggml_tensor * weights, + ggml_tensor * ids, + const ggml_tensor * clamp, + const ggml_tensor * scale, + const ggml_tensor * bias, + const ggml_cuda_topk_moe_args & args); + +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op, + const ggml_tensor * weights, + const ggml_tensor * logits, + const ggml_tensor * ids); diff --git a/llama.cpp/ggml/src/ggml-cuda/tri.cu b/llama.cpp/ggml/src/ggml-cuda/tri.cu new file mode 100644 index 0000000..44156b6 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/tri.cu @@ -0,0 +1,136 @@ +#include "common.cuh" +#include "convert.cuh" +#include "tri.cuh" +#include "ggml.h" + +template<typename T, bool prefix_keep, int add_to_split> +static __global__ void tri_kernel( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) { + const int64_t i3 = blockIdx.z; + const int64_t i2 = blockIdx.y; + const int64_t i1 = blockIdx.x; + const int64_t split_point = i1 + add_to_split; + + GGML_UNUSED_VARS(nb00, nb0); + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03; + T * dst_row = dst + i1*nb1 + i2*nb2 + i3*nb3; + + if constexpr (prefix_keep) { + for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) { + dst_row[i0] = src_row[i0]; + } + for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) { + dst_row[i0] = ggml_cuda_cast<T, float>(0.0f); + } + } else { + for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) { + dst_row[i0] = ggml_cuda_cast<T, float>(0.0f); + } + for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) { + dst_row[i0] = src_row[i0]; + } + } +} + +template<typename T> +static void tri_cuda( + const T * src, T * dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, + const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, + const ggml_tri_type ttype, + cudaStream_t stream) { + + dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1); + dim3 grid_dims(ne01, ne02, ne03); + const size_t type_size = sizeof(T); + + const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0; + const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG); + + if (prefix_keep) { + if (add_to_split == 0) { + tri_kernel<T, true, 0><<<grid_dims, block_dims, 0, stream>>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } else { // only 0 and 1 supported + tri_kernel<T, true, 1><<<grid_dims, block_dims, 0, stream>>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } + } else { + if (add_to_split == 0) { + tri_kernel<T, false, 0><<<grid_dims, block_dims, 0, stream>>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } else { + tri_kernel<T, false, 1><<<grid_dims, block_dims, 0, stream>>>( + src, dst, + ne00, ne01, ne02, ne03, + nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size, + nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size + ); + } + } +} + +void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + const ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0)); + + GGML_ASSERT(src0->type == dst->type); + + switch(src0->type) { + case GGML_TYPE_F32: + { + tri_cuda( + (const float *)src0->data, (float *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + ttype, stream + ); + } break; + case GGML_TYPE_F16: + { + tri_cuda( + (const half *)src0->data, (half *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + ttype, stream + ); + } break; + case GGML_TYPE_BF16: + { + tri_cuda( + (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], + ttype, stream + ); + } break; + default: + GGML_ABORT("fatal error"); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/tri.cuh b/llama.cpp/ggml/src/ggml-cuda/tri.cuh new file mode 100644 index 0000000..a4cc667 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/tri.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_TRI_BLOCK_SIZE 256 + +void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/tsembd.cu b/llama.cpp/ggml/src/ggml-cuda/tsembd.cu new file mode 100644 index 0000000..b91a26f --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/tsembd.cu @@ -0,0 +1,47 @@ +#include "tsembd.cuh" + +static __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) { + // blockIDx.y: idx of timesteps->ne[0] + // blockIDx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE + int i = blockIdx.y; + int j = threadIdx.x + blockIdx.x * blockDim.x; + float * embed_data = (float *)((char *)dst + i*nb1); + + int half = dim / 2; + if (dim % 2 != 0 && j == half) { + embed_data[2 * half] = 0.f; + } + + if (j >= half) { + return; + } + + float timestep = timesteps[i]; + float freq = (float)expf(-logf(max_period) * j / half); + float arg = timestep * freq; + embed_data[j] = cosf(arg); + embed_data[j + half] = sinf(arg); +} + +static void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1, + const int dim, const int max_period, cudaStream_t stream) { + int half_ceil = (dim + 1) / 2; + int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE; + dim3 gridDim(num_blocks, ne00, 1); + timestep_embedding_f32<<<gridDim, CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE, 0, stream>>>(x, dst, nb1, dim, max_period); +} + +void ggml_cuda_op_timestep_embedding(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int dim = dst->op_params[0]; + const int max_period = dst->op_params[1]; + + timestep_embedding_f32_cuda(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream); +} diff --git a/llama.cpp/ggml/src/ggml-cuda/tsembd.cuh b/llama.cpp/ggml/src/ggml-cuda/tsembd.cuh new file mode 100644 index 0000000..84340e3 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/tsembd.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE 256 + +void ggml_cuda_op_timestep_embedding(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/unary.cu b/llama.cpp/ggml/src/ggml-cuda/unary.cu new file mode 100644 index 0000000..d486606 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/unary.cu @@ -0,0 +1,562 @@ +#include "unary.cuh" +#include "convert.cuh" + +static __device__ __forceinline__ float op_abs(float x) { + return fabsf(x); +} + +static __device__ __forceinline__ float op_sgn(float x) { + return (x > 0.f ? 1.f : ((x < 0.f ? -1.f : 0.f))); +} + +static __device__ __forceinline__ float op_neg(float x) { + return -x; +} + +static __device__ __forceinline__ float op_step(float x) { + return x > 0.0f; +} + +static __device__ __forceinline__ float op_gelu(float x) { + return ggml_cuda_op_gelu_single(x); +} + +static __device__ __forceinline__ float op_gelu_erf(float x) { + const float SQRT_2_INV = 0.70710678118654752440084436210484f; + + return 0.5f*x*(1.0f + erff(x*SQRT_2_INV)); +} + +static __device__ __forceinline__ float op_gelu_quick(float x) { + const float GELU_QUICK_COEF = -1.702f; + + return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x))); +} + +static __device__ __forceinline__ float op_silu(float x) { + return ggml_cuda_op_silu_single(x); +} + +static __device__ __forceinline__ float op_tanh(float x) { + return tanhf(x); +} + +static __device__ __forceinline__ float op_relu(float x) { + return fmaxf(x, 0); +} + +static __device__ __forceinline__ float op_sigmoid(float x) { + return 1.0f / (1.0f + expf(-x)); +} + +static __device__ __forceinline__ float op_hardsigmoid(float x) { + return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f)); +} + +static __device__ __forceinline__ float op_hardswish(float x) { + return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f)); +} + +static __device__ __forceinline__ float op_exp(float x) { + return expf(x); +} + +static __device__ __forceinline__ float op_sqr(float x) { + return x * x; +} + +static __device__ __forceinline__ float op_sqrt(float x) { + return sqrtf(x); +} + +static __device__ __forceinline__ float op_sin(float x) { + return sinf(x); +} + +static __device__ __forceinline__ float op_cos(float x) { + return cosf(x); +} + +static __device__ __forceinline__ float op_log(float x) { + return logf(x); +} + +static __device__ __forceinline__ float op_expm1(float x) { + return expm1f(x); +} + +static __device__ __forceinline__ float op_softplus(float x) { + return (x > 20.0f) ? x : logf(1.0f + expf(x)); +} + +static __device__ __forceinline__ float op_elu(float x) { + return (x > 0.f) ? x : expm1f(x); +} + +static __device__ __forceinline__ float op_floor(float x) { + return floorf(x); +} + +static __device__ __forceinline__ float op_ceil(float x) { + return ceilf(x); +} + +static __device__ __forceinline__ float op_round(float x) { + return round(x); +} + +static __device__ __forceinline__ float op_trunc(float x) { + return trunc(x); +} + +template <float (*op)(float), typename T> +static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = (T)op((float)x[i]); +} + +template <float (*op)(float), typename T> +static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE; + unary_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k); +} + +template <float (*op)(float)> +void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const void * src0_d = src0->data; + void * dst_d = dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + + if (src0->type == GGML_TYPE_F16) { + unary_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream); + } else { + unary_cuda<op>((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream); + } +} + +void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_abs>(ctx, dst); +} + +void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_sgn>(ctx, dst); +} + +void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_neg>(ctx, dst); +} + +void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_step>(ctx, dst); +} + +void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_gelu>(ctx, dst); +} + +void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_gelu_erf>(ctx, dst); +} + +void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_gelu_quick>(ctx, dst); +} + +void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_silu>(ctx, dst); +} + +void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_tanh>(ctx, dst); +} + +void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_relu>(ctx, dst); +} + +void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_sigmoid>(ctx, dst); +} + +void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_hardsigmoid>(ctx, dst); +} + +void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_hardswish>(ctx, dst); +} + +void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_exp>(ctx, dst); +} + +void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_sqr>(ctx, dst); +} + +void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_sqrt>(ctx, dst); +} + +void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_sin>(ctx, dst); +} + +void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_cos>(ctx, dst); +} + +void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_log>(ctx, dst); +} + +void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_elu>(ctx, dst); +} + +void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_floor>(ctx, dst); +} + +void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_ceil>(ctx, dst); +} + +void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_round>(ctx, dst); +} + +void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_trunc>(ctx, dst); +} + +void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_expm1>(ctx, dst); +} + +void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary<op_softplus>(ctx, dst); +} +/* gated ops */ + +template <float (*op)(float), typename T> +static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) { + const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + // perform base op and multiply with gate (either offset in same tensor or a separate one) + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + + dst[i] = (T)(op((float)x[j0]) * (float)g[j1]); +} + +template <float (*op)(float), typename T> +static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) { + const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE; + unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1); +} + +template <float (*op)(float)> +void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + void * src0_d = src0->data; + void * src1_d = src1 ? src1->data : src0->data; + const int64_t src0_o = src0->nb[1]; + const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + void * dst_d = dst->data; + const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(src0->nb[0] == ggml_element_size(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src1->nb[0] == ggml_element_size(src1)); + GGML_ASSERT(src1->ne[0] == nc); + GGML_ASSERT(src0->type == src1->type); + } + + const int32_t swapped = ((const int32_t *) dst->op_params)[1]; + + if (src0->type == GGML_TYPE_F16) { + half * src0_p = (half *) src0_d; + half * src1_p = (half *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + unary_gated_cuda<op>(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), stream); + } else { + float * src0_p = (float *) src0_d; + float * src1_p = (float *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + unary_gated_cuda<op>(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), stream); + } +} + +void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated<op_relu>(ctx, dst); +} + +void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated<op_gelu>(ctx, dst); +} + +void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated<op_silu>(ctx, dst); +} + +void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated<op_gelu_erf>(ctx, dst); +} + +void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst); +} + +// swiglu_oai + +template <typename T> +static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) { + const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + // perform base op and multiply with gate (either offset in same tensor or a separate one) + const int64_t j0 = (i / n) * o0 + (i % n); + const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + + float xi = x[j0]; + float gi = g[j1]; + + dst[i] = ggml_cuda_op_swiglu_oai_single(xi, gi, alpha, limit); +} + +template <typename T> +static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) { + const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE; + swiglu_oai_kernel<<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit); +} + +void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + void * src0_d = src0->data; + void * src1_d = src1 ? src1->data : src0->data; + const int64_t src0_o = src0->nb[1]; + const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1]; + void * dst_d = dst->data; + const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(src0->nb[0] == ggml_element_size(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == dst->type); + GGML_ASSERT(dst->ne[0] == nc); + GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0)); + + if (src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src1->nb[0] == ggml_element_size(src1)); + GGML_ASSERT(src1->ne[0] == nc); + GGML_ASSERT(src0->type == src1->type); + } + + //const int32_t swapped = ((const int32_t *) dst->op_params)[1]; + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + const float alpha = ggml_get_op_params_f32(dst, 2); + const float limit = ggml_get_op_params_f32(dst, 3); + + float * src0_p = (float *) src0_d; + float * src1_p = (float *) src1_d; + + if (!src1) { + src0_p += swapped ? nc : 0; + src1_p += swapped ? 0 : nc; + } + + swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream); +} + +/* CUDA kernel + launcher for xIELU */ + +template <typename T> +static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + const float xi = ggml_cuda_cast<float>(x[i]); + + const float gate_pos = (xi > 0.0f); + const float y_pos = alpha_p * xi * xi + beta * xi; + const float min_v_eps = fminf(xi, eps); + const float y_neg = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi; + const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg; + + dst[i] = ggml_cuda_cast<T>(out); +} + +template <typename T> +static void xielu_cuda(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps, cudaStream_t stream) { + const int num_blocks = (k + CUDA_XIELU_BLOCK_SIZE) / CUDA_XIELU_BLOCK_SIZE; + xielu_kernel<<<num_blocks, CUDA_XIELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, alpha_n, alpha_p, beta, eps); +} + +void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const void * src0_d = src0->data; + void * dst_d = dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + + const float alpha_n = ggml_get_op_params_f32(dst, 1); + const float alpha_p = ggml_get_op_params_f32(dst, 2); + const float beta = ggml_get_op_params_f32(dst, 3); + const float eps = ggml_get_op_params_f32(dst, 4); + + if (src0->type == GGML_TYPE_F16) { + xielu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream); + } else { + xielu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream); + } +} + + + +/* silu_back */ + +static __device__ __forceinline__ float op_silu_back(float grad, float x) { + const float s = 1.0f / (1.0f + expf(-x)); + return grad * s * (1.0f + x * (1.0f - s)); +} + +template <class T> +static __global__ void silu_back_kernel(const T * grad, const T * xf, T * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = (T)op_silu_back((float)grad[i], (float)xf[i]); +} + +template <class T> +static void silu_back_cuda(const T * grad, const T * x, T * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE; + silu_back_kernel<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k); +} + +void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; // input from forward pass + const ggml_tensor * src1 = dst->src[1]; // grads of forward pass output + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + + if (src0->type == GGML_TYPE_F16) { + silu_back_cuda((const half *)src0_d, (const half *)src1_d, (half *)dst_d, ggml_nelements(src0), stream); + } else { + silu_back_cuda((const float*)src0_d, (const float*)src1_d, (float *)dst_d, ggml_nelements(src0), stream); + } +} + +/* leaky relu */ + +static __device__ __forceinline__ float op_leaky_relu(float x, const float negative_slope) { + return fmaxf(x, 0) + fminf(x, 0.0f) * negative_slope; +} + +template <class T> +static __global__ void leaky_relu_kernel(const T * x, T * dst, const int k, const float negative_slope) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = (T)op_leaky_relu((float)x[i], negative_slope); +} + +template <class T> +static void leaky_relu_cuda(const T * x, T * dst, const int k, const float negative_slope, cudaStream_t stream) { + const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; + leaky_relu_kernel<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope); +} + +void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const void * src0_d = src0->data; + void * dst_d = dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + + float negative_slope; + memcpy(&negative_slope, dst->op_params, sizeof(float)); + + if (src0->type == GGML_TYPE_F16) { + leaky_relu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), negative_slope, stream); + } else { + leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/unary.cuh b/llama.cpp/ggml/src/ggml-cuda/unary.cuh new file mode 100644 index 0000000..609046e --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/unary.cuh @@ -0,0 +1,110 @@ +#pragma once +#include "common.cuh" + +#define CUDA_NEG_BLOCK_SIZE 256 +#define CUDA_STEP_BLOCK_SIZE 256 +#define CUDA_GELU_BLOCK_SIZE 256 +#define CUDA_SILU_BLOCK_SIZE 256 +#define CUDA_SILU_BACK_BLOCK_SIZE 256 +#define CUDA_TANH_BLOCK_SIZE 256 +#define CUDA_RELU_BLOCK_SIZE 256 +#define CUDA_SIGMOID_BLOCK_SIZE 256 +#define CUDA_HARDSIGMOID_BLOCK_SIZE 256 +#define CUDA_EXP_BLOCK_SIZE 256 +#define CUDA_HARDSWISH_BLOCK_SIZE 256 +#define CUDA_SQR_BLOCK_SIZE 256 +#define CUDA_SQRT_BLOCK_SIZE 256 +#define CUDA_SIN_BLOCK_SIZE 256 +#define CUDA_COS_BLOCK_SIZE 256 +#define CUDA_GLU_BLOCK_SIZE 256 +#define CUDA_XIELU_BLOCK_SIZE 256 + +void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) { + return x / (1.0f + expf(-x)); +} + +__device__ __forceinline__ float ggml_cuda_op_gelu_single(float x) { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + + return 0.5f * x * (1.0f + tanhf(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x))); +} + +__device__ __forceinline__ float ggml_cuda_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) { + x = fminf(x, limit); + g = fmaxf(fminf(g, limit), -limit); + + float out_glu = x / (1.0f + expf(-x * alpha)); + out_glu = out_glu * (1.0f + g); + return out_glu; +} diff --git a/llama.cpp/ggml/src/ggml-cuda/upscale.cu b/llama.cpp/ggml/src/ggml-cuda/upscale.cu new file mode 100644 index 0000000..6bdf3cd --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/upscale.cu @@ -0,0 +1,293 @@ +#include "upscale.cuh" + +static __global__ void upscale_f32(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int ne13, + const float sf0, const float sf1, const float sf2, const float sf3) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index >= ne10 * ne11 * ne12 * ne13) { + return; + } + + int i10 = index % ne10; + int i11 = (index / ne10) % ne11; + int i12 = (index / (ne10 * ne11)) % ne12; + int i13 = (index / (ne10 * ne11 * ne12)) % ne13; + + int i00 = i10 / sf0; + int i01 = i11 / sf1; + int i02 = i12 / sf2; + int i03 = i13 / sf3; + + dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) ); +} + +static __global__ void upscale_f32_bilinear(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + const int64_t index = threadIdx.x + blockIdx.x * blockDim.x; + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset; + int y0_src = (int)floorf(y_src_f); + int y1_src = y0_src + 1; + + y0_src = max(0, min(y0_src, ne01_src - 1)); + y1_src = max(0, min(y1_src, ne01_src - 1)); + + float dy = y_src_f - (float)y0_src; + dy = max(0.0f, min(dy, 1.0f)); + + float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset; + int x0_src = (int)floorf(x_src_f); + int x1_src = x0_src + 1; + + x0_src = max(0, min(x0_src, ne00_src - 1)); + x1_src = max(0, min(x1_src, ne00_src - 1)); + + float dx = x_src_f - (float)x0_src; + dx = max(0.0f, min(dx, 1.0f)); + + const float * p_a = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03); + const float * p_b = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03); + const float * p_c = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03); + const float * p_d = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03); + + const float val_a = *p_a; + const float val_b = *p_b; + const float val_c = *p_c; + const float val_d = *p_d; + + float result = val_a * (1.0f - dx) * (1.0f - dy) + + val_b * dx * (1.0f - dy) + + val_c * (1.0f - dx) * dy + + val_d * dx * dy; + + dst[index] = result; +} + +// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True) +// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp +static __global__ void upscale_f32_bilinear_antialias(const float * src0, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + const int64_t index = threadIdx.x + blockIdx.x * blockDim.x; + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y = ((float)i11_dst + pixel_offset) / sf1; + const float x = ((float)i10_dst + pixel_offset) / sf0; + + // support and invscale, minimum 1 pixel for bilinear + const float support1 = max(1.0f / sf1, 1.0f); + const float invscale1 = 1.0f / support1; + const float support0 = max(1.0f / sf0, 1.0f); + const float invscale0 = 1.0f / support0; + + // the range of source pixels that contribute + const int64_t x_min = max(int64_t(0), int64_t(x - support0 + pixel_offset)); + const int64_t x_max = min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset)); + const int64_t y_min = max(int64_t(0), int64_t(y - support1 + pixel_offset)); + const int64_t y_max = min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset)); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + auto triangle_filter = [](float x) -> float { + return max(1.0f - fabsf(x), 0.0f); + }; + + for (int64_t sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1); + + for (int64_t sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = *(const float *)((const char *)src0 + sx*nb00 + sy*nb01 + i02_src*nb02 + i03_src*nb03); + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + dst[index] = val; +} + +namespace bicubic_interpolation { +// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm +__device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch) + +static __device__ float weight1(float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; }; +static __device__ float weight2(float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; }; + +static __device__ float bicubic(float p0, float p1, float p2, float p3, float x) { + const float w0 = weight2(x + 1); + const float w1 = weight1(x + 0); + const float w2 = weight1(1 - x); + const float w3 = weight2(2 - x); + return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3; +}; +} // namespace bicubic_interpolation + +static __global__ void upscale_f32_bicubic(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + using bicubic_interpolation::bicubic; + + const int64_t index = threadIdx.x + blockIdx.x * blockDim.x; + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset; + const int y0_src = (int)floorf(y_src_f); + const float dy = y_src_f - (float)y0_src; + + const float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset; + const int x0_src = (int)floorf(x_src_f); + const float dx = x_src_f - (float)x0_src; + + const char * x_base = (const char *)x + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03; + + auto load = [=](int x_off, int y_off) -> float { + int i00_src = max(0, min(x0_src + x_off, ne00_src - 1)); + int i01_src = max(0, min(y0_src + y_off, ne01_src - 1)); + return *(const float *)(x_base + (int64_t)i00_src * nb00 + (int64_t)i01_src * nb01); + }; + + const float result = bicubic( + bicubic(load(-1,-1), load(0,-1), load(1,-1), load(2,-1), dx), + bicubic(load(-1, 0), load(0, 0), load(1, 0), load(2, 0), dx), + bicubic(load(-1, 1), load(0, 1), load(1, 1), load(2, 1), dx), + bicubic(load(-1, 2), load(0, 2), load(1, 2), load(2, 2), dx), dy); + + dst[index] = result; +} + +static void upscale_f32_cuda(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int ne13, + const float sf0, const float sf1, const float sf2, const float sf3, + cudaStream_t stream) { + const int64_t dst_size = ne10 * ne11 * ne12 * ne13; + const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; + + upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3); +} + +static void upscale_f32_bilinear_cuda(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset, bool antialias, cudaStream_t stream) { + const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; + + if (antialias) { + upscale_f32_bilinear_antialias<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + } else { + upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + } +} + +static void upscale_f32_bicubic_cuda(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset, cudaStream_t stream) { + const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; + + upscale_f32_bicubic<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); +} + +void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int mode_flags = dst->op_params[0]; + const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF); + + float sf0 = (float)dst->ne[0]/src0->ne[0]; + float sf1 = (float)dst->ne[1]/src0->ne[1]; + float sf2 = (float)dst->ne[2]/src0->ne[2]; + const float sf3 = (float)dst->ne[3]/src0->ne[3]; + + float pixel_offset = 0.5f; + if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { + sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0; + sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1; + pixel_offset = 0.0f; + } + + if (mode == GGML_SCALE_MODE_NEAREST) { + upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); + } else if (mode == GGML_SCALE_MODE_BILINEAR) { + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS); + upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + sf0, sf1, sf2, sf3, pixel_offset, antialias, stream); + } else if (mode == GGML_SCALE_MODE_BICUBIC) { + upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + sf0, sf1, sf2, sf3, pixel_offset, stream); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/upscale.cuh b/llama.cpp/ggml/src/ggml-cuda/upscale.cuh new file mode 100644 index 0000000..d4d7652 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/upscale.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_UPSCALE_BLOCK_SIZE 256 + +void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh b/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh new file mode 100644 index 0000000..6baab11 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh @@ -0,0 +1,1223 @@ +#pragma once + +#include "common.cuh" + +#include <cstdint> + +static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) { + const uint8_t * x8 = (const uint8_t *) x; + + int x32 = x8[4*i32 + 0] << 0; + x32 |= x8[4*i32 + 1] << 8; + x32 |= x8[4*i32 + 2] << 16; + x32 |= x8[4*i32 + 3] << 24; + + return x32; +} + +static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) { + const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment + + int x32 = x16[2*i32 + 0] << 0; + x32 |= x16[2*i32 + 1] << 16; + + return x32; +} + +static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) { + return ((const int *) x)[i32]; // assume at least 4 byte alignment +} + +// q4 contains 8 indices with 4 bit each. +// This function selects those bytes from table that are at those indices and returns them as int2. +// The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4. +static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) { +#if defined(GGML_USE_HIP) + // Load the 16-byte table into four 32-bit unsigned integers. + const uint32_t *values = (const uint32_t *)table; + + const uint32_t q_even = q4; + const uint32_t q_odd = (q4 >> 4); + + // Perform lookups in the lower half of the table (indices 0-7). + uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707); + uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707); + + // Perform lookups in the upper half of the table (indices 8-15). + uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707); + uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707); + + // Select between the low and high results based on the MSB of each index nibble. + uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1); + uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even); + uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1); + uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd); + + return make_int2(res_x, res_y); +#elif !defined(GGML_USE_MUSA) + // CUDA does not have an instruction for selecting bytes with 4 bit indices. + // However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead. + const uint32_t * table32 = (const uint32_t *) table; + + // __byte_perm selects bytes based on the lower 16 bits in its third argument. + // Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift. + // To handle the fourth bit, first call _byte_perm both for the low and the high 64 bit of table, using the low 3 bits. + // Then, call __byte_perm again to select from the low and high bytes based on the fourth bit. + uint32_t tmp[2]; + const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1)); +#pragma unroll + for (uint32_t i = 0; i < 2; ++i) { + const uint32_t shift = 16 * i; + + const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift); + const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift); + tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift); + } + + // tmp contains the bytes from tyble in the same order as the 4 bit indices in q4. + // However, for the result we need ints with all even/odd 4 bit indices in q4. + // Therefore, 2 more calls to __byte_perm to put the bytes in the correct order. + return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531)); +#else + // Generic implementation. + const int q0_32 = (q4 >> 0) & 0x0F0F0F0F; + const int8_t * q0_8 = (const int8_t *) &q0_32; + const char4 val0_8 = make_char4( + table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]); + + const int q1_32 = (q4 >> 4) & 0x0F0F0F0F; + const int8_t * q1_8 = (const int8_t *) &q1_32; + const char4 val1_8 = make_char4( + table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]); + + return make_int2(*((const int *) &val0_8), *((const int *) &val1_8)); +#endif +} + +// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called +// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q + +#define VDR_Q4_0_Q8_1_MMVQ 2 +#define VDR_Q4_0_Q8_1_MMQ 4 + +template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl( + const int * v, const int * u, const float & d4, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); + } + + const float2 ds8f = __half22float2(ds8); + + // second part effectively subtracts 8 from each quant value + return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y); +} + +#define VDR_Q4_1_Q8_1_MMVQ 2 +#define VDR_Q4_1_Q8_1_MMQ 4 + +template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl( + const int * v, const int * u, const half2 & dm4, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + const int vi0 = (v[i] >> 0) & 0x0F0F0F0F; + const int vi1 = (v[i] >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); + } + +#ifdef FAST_FP16_AVAILABLE + const float2 tmp = __half22float2(__hmul2(dm4, ds8)); + const float d4d8 = tmp.x; + const float m4s8 = tmp.y; +#else + const float2 dm4f = __half22float2(dm4); + const float2 ds8f = __half22float2(ds8); + const float d4d8 = dm4f.x * ds8f.x; + const float m4s8 = dm4f.y * ds8f.y; +#endif // FAST_FP16_AVAILABLE + + // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it + return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1)); +} + +#define VDR_Q5_0_Q8_1_MMVQ 2 +#define VDR_Q5_0_Q8_1_MMQ 4 + +template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl( + const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + + int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + } + + const float2 ds8f = __half22float2(ds8); + + // second part effectively subtracts 16 from each quant value + return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y); +} + +#define VDR_Q5_1_Q8_1_MMVQ 2 +#define VDR_Q5_1_Q8_1_MMQ 4 + +template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl( + const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits + vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4 + vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12 + vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20 + vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28 + sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values + + int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits + vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4 + vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12 + vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20 + vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28 + sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values + } + +#ifdef FAST_FP16_AVAILABLE + const float2 tmp = __half22float2(__hmul2(dm5, ds8)); + const float d5d8 = tmp.x; + const float m5s8 = tmp.y; +#else + const float2 dm5f = __half22float2(dm5); + const float2 ds8f = __half22float2(ds8); + const float d5d8 = dm5f.x * ds8f.x; + const float m5s8 = dm5f.y * ds8f.y; +#endif // FAST_FP16_AVAILABLE + + // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it + return sumi*d5d8 + m5s8 / (QI5_1 / vdr); +} + +#define VDR_Q8_0_Q8_1_MMVQ 2 +#define VDR_Q8_0_Q8_1_MMQ 8 + +template <typename T, int vdr> static __device__ __forceinline__ T vec_dot_q8_0_q8_1_impl( + const int * v, const int * u, const T & d8_0, const T & d8_1) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(v[i], u[i], sumi); + } + + return d8_0*d8_1 * ((T) sumi); +} + +template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl( + const int * v, const int * u, const half2 & dm8, const half2 & ds8) { + + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(v[i], u[i], sumi); + } + +#ifdef FAST_FP16_AVAILABLE + const float2 tmp = __half22float2(__hmul2(dm8, ds8)); + const float d8d8 = tmp.x; + const float m8s8 = tmp.y; +#else + const float2 dm8f = __half22float2(dm8); + const float2 ds8f = __half22float2(ds8); + const float d8d8 = dm8f.x * ds8f.x; + const float m8s8 = dm8f.y * ds8f.y; +#endif // FAST_FP16_AVAILABLE + + // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it + return sumi*d8d8 + m8s8 / (QI8_1 / vdr); +} + +template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_impl( + const int * v, const int * u, const float * d8_0, const float & d8_1) { + + float sumf = 0.0f; + +#pragma unroll + for (int i0 = 0; i0 < vdr; i0 += QI8_0/2) { + int sumi = 0; + +#pragma unroll + for (int i = i0; i < i0 + QI8_0/2; ++i) { + // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(v[i], u[i], sumi); + } + + sumf += d8_0[i0/(QI8_0/2)]*sumi; + } + + return d8_1*sumf; +} + +#define VDR_MXFP4_Q8_1_MMVQ 2 +#define VDR_MXFP4_Q8_1_MMQ 4 + +static __device__ __forceinline__ float vec_dot_mxfp4_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx; + + const int * q8 = (const int *) bq8_1->qs + iqs; + + int sumi = 0; +#pragma unroll + for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) { + const int aux_q4 = get_int_b1(bq4->qs, iqs + l); + const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4); + + sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi); + sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi); + } + + const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds); + return d * sumi; +} + +#define VDR_Q2_K_Q8_1_MMVQ 1 +#define VDR_Q2_K_Q8_1_MMQ 4 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( + const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const half2 & dm2, const float * __restrict__ d8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR2_K; ++i) { + const int sc = scales[2*i]; + + const int vi = (v >> (2*i)) & 0x03030303; + + sumf_d += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product + + // fill int with 4x m + int m = sc >> 4; + m |= m << 8; + m |= m << 16; + sumf_m += d8[i] * ggml_cuda_dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values + } + + const float2 dm2f = __half22float2(dm2); + + return dm2f.x*sumf_d - dm2f.y*sumf_m; +} + +// contiguous v/x + u/y values +template <int ns8> +static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) { + + float sumf = 0.0f; + float sumf_d8 = 0.0f; + +#pragma unroll + for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) { + const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]); + int sumi_d0 = 0; + + const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]); + int sumi_d1 = 0; + +#pragma unroll + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0); + } + sumf_d8 += dm2f0.x * sumi_d0; + +#pragma unroll + for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) { + sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1); + } + sumf_d8 += dm2f1.x * sumi_d1; + + if (i0/QI8_1 < ns8) { + const float2 s8f = __half22float2(s8[i0/QI8_1]); + sumf -= dm2f0.y*s8f.x; + sumf -= dm2f1.y*s8f.y; + } else { + int sumi_m0 = 0; +#pragma unroll + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0); + } + sumf_d8 -= dm2f0.y * sumi_m0; + + int sumi_m1 = 0; +#pragma unroll + for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) { + sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1); + } + sumf_d8 -= dm2f1.y * sumi_m1; + } + } + + return sumf + d8*sumf_d8; +} + +#define VDR_Q3_K_Q8_1_MMVQ 1 +#define VDR_Q3_K_Q8_1_MMQ 2 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( + const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales, + const int & scale_offset, const float & d3, const float * __restrict__ d8) { + + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + const int isc = scale_offset + 2*i; + + const int isc_low = isc % (QK_K/32); + const int sc_shift_low = 4 * (isc / (QK_K/32)); + const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF; + + const int isc_high = isc % (QK_K/64); + const int sc_shift_high = 2 * (isc / (QK_K/64)); + const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4; + + const int sc = (sc_low | sc_high) - 32; + + const int vil = (vl >> (2*i)) & 0x03030303; + + const int vih = ((vh >> i) << 2) & 0x04040404; + + const int vi = __vsubss4(vil, vih); + + sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d3 * sumf; +} + +// contiguous v/x + u/y values +static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales, + const float & d3, const float & d8) { + + int sumi = 0; + +#pragma unroll + for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) { + int sumi_sc = 0; + +#pragma unroll + for (int i = i0; i < i0 + QI8_1/2; ++i) { + sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product + } + + sumi += sumi_sc * scales[i0 / (QI8_1/2)]; + } + + return d3*d8 * sumi; +} + +#define VDR_Q4_K_Q8_1_MMVQ 2 +#define VDR_Q4_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR4_K; ++i) { + const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F; + const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F; + + const int dot1 = ggml_cuda_dp4a(v1i, u[2*i+1], ggml_cuda_dp4a(v0i, u[2*i+0], 0)); // SIMD dot product + const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+1], ggml_cuda_dp4a(0x01010101, u[2*i+0], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; +} + +// contiguous v/x + u/y values +static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) { + int sumi_d = 0; + +#pragma unroll + for (int j = 0; j < QI8_1; ++j) { + sumi_d = ggml_cuda_dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product + } + + const float2 ds8f = __half22float2(ds8[i]); + + sumf_d += ds8f.x * (sc[i] * sumi_d); + sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; +} + +#define VDR_Q5_K_Q8_1_MMVQ 2 +#define VDR_Q5_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( + const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR5_K; ++i) { + const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F; + const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F; + + const int vh0i = ((vh[0] >> i) << 4) & 0x10101010; + const int vh1i = ((vh[1] >> i) << 4) & 0x10101010; + + const int v0i = vl0i | vh0i; + const int v1i = vl1i | vh1i; + + const int dot1 = ggml_cuda_dp4a(v0i, u[2*i+0], ggml_cuda_dp4a(v1i, u[2*i+1], 0)); // SIMD dot product + const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+0], ggml_cuda_dp4a(0x01010101, u[2*i+1], 0)); // sum of u + + sumf_d += d8[i] * (dot1 * sc[i]); + sumf_m += d8[i] * (dot2 * m[i]); + + } + + const float2 dm5f = __half22float2(dm5); + + return dm5f.x*sumf_d - dm5f.y*sumf_m; +} + +// contiguous v/x + u/y values +static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, + const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { + + float sumf_d = 0.0f; + float sumf_m = 0.0f; + +#pragma unroll + for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) { + int sumi_d = 0; + +#pragma unroll + for (int j = 0; j < QI8_1; ++j) { + sumi_d = ggml_cuda_dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product + } + + const float2 ds8f = __half22float2(ds8[i]); + + sumf_d += ds8f.x * (sc[i] * sumi_d); + sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val + } + + const float2 dm4f = __half22float2(dm4); + + return dm4f.x*sumf_d - dm4f.y*sumf_m; +} + +#define VDR_Q6_K_Q8_1_MMVQ 1 +#define VDR_Q6_K_Q8_1_MMQ 8 + +// contiguous v/x values +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( + const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales, + const float & d, const float * __restrict__ d8) { + + float sumf = 0.0f; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + const int sc = scales[4*i]; + + const int vil = (vl >> (4*i)) & 0x0F0F0F0F; + + const int vih = ((vh >> (4*i)) << 4) & 0x30303030; + + const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32 + + sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product + } + + return d*sumf; +} + +// contiguous v/x + u/y values +static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( + const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, + const float & d6, const float * __restrict__ d8) { + + float sumf_d = 0.0f; + + const int sc_packed = get_int_b4(sc, 0); + const int8_t * sc_reg = (const int8_t *) &sc_packed; + +#pragma unroll + for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) { + int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale + +#pragma unroll + for (int i = i0; i < i0 + 2; ++i) { + sumi_d.x = ggml_cuda_dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product + sumi_d.x = ggml_cuda_dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product + + sumi_d.y = ggml_cuda_dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product + sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product + } + + sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y); + } + + return d6 * sumf_d; +} + +static __device__ __forceinline__ float vec_dot_q4_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq + kbx; + + int v[VDR_Q4_0_Q8_1_MMVQ]; + int u[2*VDR_Q4_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_b2(bq4_0->qs, iqs + i); + u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI4_0); + } + + return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds); +} + + +static __device__ __forceinline__ float vec_dot_q4_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq + kbx; + + int v[VDR_Q4_1_Q8_1_MMVQ]; + int u[2*VDR_Q4_1_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) { + v[i] = get_int_b4(bq4_1->qs, iqs + i); + u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI4_1); + } + + return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds); +} + +static __device__ __forceinline__ float vec_dot_q5_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq + kbx; + + int vl[VDR_Q5_0_Q8_1_MMVQ]; + int vh[VDR_Q5_0_Q8_1_MMVQ]; + int u[2*VDR_Q5_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) { + vl[i] = get_int_b2(bq5_0->qs, iqs + i); + vh[i] = get_int_b2(bq5_0->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI5_0); + } + + return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds); +} + +static __device__ __forceinline__ float vec_dot_q5_1_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq + kbx; + + int vl[VDR_Q5_1_Q8_1_MMVQ]; + int vh[VDR_Q5_1_Q8_1_MMVQ]; + int u[2*VDR_Q5_1_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) { + vl[i] = get_int_b4(bq5_1->qs, iqs + i); + vh[i] = get_int_b4(bq5_1->qh, 0) >> (4 * (iqs + i)); + u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i); + u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI5_1); + } + + return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds); +} + +static __device__ __forceinline__ float vec_dot_q8_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq + kbx; + + int v[VDR_Q8_0_Q8_1_MMVQ]; + int u[VDR_Q8_0_Q8_1_MMVQ]; + +#pragma unroll + for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) { + v[i] = get_int_b2(bq8_0->qs, iqs + i); + u[i] = get_int_b4(bq8_1->qs, iqs + i); + } + + return vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds)); +} + +static __device__ __forceinline__ float vec_dot_q2_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_q2_K * bq2_K = (const block_q2_K *) vbq + kbx; + + const int bq8_offset = QR2_K * (iqs / QI8_1); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + const uint8_t * scales = bq2_K->scales + scale_offset; + + const int v = get_int_b4(bq2_K->qs, iqs); + int u[QR2_K]; + float d8[QR2_K]; + +#pragma unroll + for (int i = 0; i < QR2_K; ++ i) { + u[i] = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = __low2float(bq8_1[bq8_offset + i].ds); + } + + return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8); +} + +static __device__ __forceinline__ float vec_dot_q3_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_q3_K * bq3_K = (const block_q3_K *) vbq + kbx; + + const int bq8_offset = QR3_K * (iqs / (QI3_K/2)); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2); + + const float d = bq3_K->d; + + const int vl = get_int_b2(bq3_K->qs, iqs); + + // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted + const int vh = ~get_int_b2(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset; + + int u[QR3_K]; + float d8[QR3_K]; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + u[i] = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1); + d8[i] = __low2float(bq8_1[bq8_offset + i].ds); + } + + return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); +} + +static __device__ __forceinline__ float vec_dot_q4_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx; + + int v[2]; + int u[2*QR4_K]; + float d8[QR4_K]; + + // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6 + const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2)); + + // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12 + // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44 + // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76 + // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108 + + const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); + v[0] = q4[0]; + v[1] = q4[4]; + + const uint16_t * scales = (const uint16_t *)bq4_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + + for (int i = 0; i < QR4_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = __low2float(bq8i->ds); + + const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; + } + + return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8); +} + +static __device__ __forceinline__ float vec_dot_q5_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_q5_K * bq5_K = (const block_q5_K *) vbq + kbx; + + int vl[2]; + int vh[2]; + int u[2*QR5_K]; + float d8[QR5_K]; + + const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2)); + const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4)); + const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4)); + + vl[0] = ql[0]; + vl[1] = ql[4]; + + vh[0] = qh[0] >> bq8_offset; + vh[1] = qh[4] >> bq8_offset; + + const uint16_t * scales = (const uint16_t *)bq5_K->scales; + uint16_t aux[2]; + const int j = bq8_offset/2; + if (j < 2) { + aux[0] = scales[j+0] & 0x3f3f; + aux[1] = scales[j+2] & 0x3f3f; + } else { + aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2); + aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2); + } + const uint8_t * sc = (const uint8_t *)aux; + const uint8_t * m = sc + 2; + +#pragma unroll + for (int i = 0; i < QR5_K; ++i) { + const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; + d8[i] = __low2float(bq8i->ds); + + const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); + u[2*i+0] = q8[0]; + u[2*i+1] = q8[4]; + } + + return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8); +} + +static __device__ __forceinline__ float vec_dot_q6_K_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_q6_K * bq6_K = (const block_q6_K *) vbq + kbx; + + const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4); + const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8); + const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4)); + + const int vl = get_int_b2(bq6_K->ql, iqs); + const int vh = get_int_b2(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift; + + const int8_t * scales = bq6_K->scales + scale_offset; + + int u[QR6_K]; + float d8[QR6_K]; + +#pragma unroll + for (int i = 0; i < QR6_K; ++i) { + u[i] = get_int_b4(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); + d8[i] = __low2float(bq8_1[bq8_offset + 2*i].ds); + } + + return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); +} + +#define VDR_IQ2_XXS_Q8_1_MMVQ 2 +#define VDR_IQ2_XXS_Q8_1_MMQ 2 + +static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq + kbx; + + const int q2 = get_int_b2(bq2->qs, iqs); + const uint8_t * aux8 = (const uint8_t *) &q2; + const uint32_t aux32 = get_int_b2(bq2->qs, iqs + 1); + + int sumi = 0; +#pragma unroll + for (int k0 = 0; k0 < 8; k0 += 2) { + const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]); + const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F]; + + const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); + const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0); + sumi = ggml_cuda_dp4a(grid0, u0, sumi); + + const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); + const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1); + sumi = ggml_cuda_dp4a(grid1, u1, sumi); + } + + const int ls = aux32 >> 28; + sumi = (ls*sumi + sumi/2)/4; + const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds); + return d * sumi; +} + +#define VDR_IQ2_XS_Q8_1_MMVQ 2 +#define VDR_IQ2_XS_Q8_1_MMQ 2 + +static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq + kbx; + + const int2 q2_packed = make_int2(get_int_b2(bq2->qs, iqs + 0), get_int_b2(bq2->qs, iqs + 1)); + const uint16_t * q2 = (const uint16_t *) &q2_packed; + const int ls0 = bq2->scales[iqs/2] & 0x0F; + const int ls1 = bq2->scales[iqs/2] >> 4; + + int sumi0 = 0; + int sumi1 = 0; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF)); + const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9)); + + const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + + const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); + + if (l0 < 4) { + sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0); + sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0); + } else { + sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1); + sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1); + } + } + const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4; + const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds); + return d * sumi; +} + +#define VDR_IQ2_S_Q8_1_MMVQ 2 +#define VDR_IQ2_S_Q8_1_MMQ 2 + +static __device__ __forceinline__ float vec_dot_iq2_s_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_iq2_s * bq2 = (const block_iq2_s *) vbq + kbx; + + const int qs_packed = get_int_b2(bq2->qs, iqs/2); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bq2->qh[iqs/2]; + + const int signs_packed_32 = get_int_b2(bq2->qs, QK_K/32 + iqs/2); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + + const int ls0 = bq2->scales[iqs/2] & 0x0F; + const int ls1 = bq2->scales[iqs/2] >> 4; + + int sumi0 = 0; + int sumi1 = 0; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const int * grid_pos = (const int *)(iq2s_grid + (qs[l0/2] | ((qh << (8-l0)) & 0x300))); + + const int signs0 = __vcmpne4(((signs_packed_8[l0/2] & 0x03) << 7) | ((signs_packed_8[l0/2] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l0/2] & 0x30) << 3) | ((signs_packed_8[l0/2] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); + + const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); + + if (l0 < 4) { + sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0); + sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0); + } else { + sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1); + sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1); + } + } + const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4; + + const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds); + return d * sumi; +} + +#define VDR_IQ3_XXS_Q8_1_MMVQ 2 +#define VDR_IQ3_XXS_Q8_1_MMQ 2 + +static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_iq3_xxs * bq3 = (const block_iq3_xxs *) vbq + kbx; + + const int2 q3_packed = make_int2(get_int_b2(bq3->qs, iqs), get_int_b2(bq3->qs, iqs+1)); + const uint8_t * q3 = (const uint8_t *) &q3_packed; + const uint32_t aux32 = get_int_b2(bq3->qs, QK_K/16 + iqs/2); + + int sumi = 0; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]); + + const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F)); + + const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + + const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); + + sumi = ggml_cuda_dp4a(grid_l, u0, sumi); + sumi = ggml_cuda_dp4a(grid_h, u1, sumi); + } + + const int ls = aux32 >> 28; + sumi = (ls*sumi + sumi/2)/2; + const float d = __half2float(bq3->d) * __low2float(bq8_1[iqs/2].ds); + return d * sumi; +} + +#define VDR_IQ3_S_Q8_1_MMVQ 2 +#define VDR_IQ3_S_Q8_1_MMQ 2 + +// TODO: don't use lookup table for signs +static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_iq3_s * bq3 = (const block_iq3_s *) vbq + kbx; + + const int2 qs_packed = make_int2(get_int_b2(bq3->qs, iqs + 0), get_int_b2(bq3->qs, iqs + 1)); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bq3->qh[iqs/2]; + + const int signs_packed_32 = get_int_b2(bq3->signs, iqs/2); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + + int sumi = 0; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const int2 grid_pos = make_int2( + iq3s_grid[qs[l0 + 0] | ((qh << (8 - l0)) & 0x100)], + iq3s_grid[qs[l0 + 1] | ((qh << (7 - l0)) & 0x100)]); + + const int signs0 = __vcmpne4(((signs_packed_8[l0/2] & 0x03) << 7) | ((signs_packed_8[l0/2] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l0/2] & 0x30) << 3) | ((signs_packed_8[l0/2] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); + + const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); + + sumi = ggml_cuda_dp4a(grid_l, u0, sumi); + sumi = ggml_cuda_dp4a(grid_h, u1, sumi); + } + + sumi *= 1 + 2*((bq3->scales[iqs/4] >> ((iqs << 1) & 0x04)) & 0x0F); + + const float d = __half2float(bq3->d) * __low2float(bq8_1[iqs/2].ds); + return d * sumi; +} + +#define VDR_IQ1_S_Q8_1_MMVQ 1 +#define VDR_IQ1_S_Q8_1_MMQ 1 + +static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx; + + const int qs_packed = get_int_b2(bq1->qs, iqs); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bq1->qh[iqs]; + + int sumi = 0; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const int grid = iq1s_grid_gpu[qs[l0/2] | (((qh >> 3*(l0/2)) & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + + const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1); + + sumi = ggml_cuda_dp4a(grid0, u0, sumi); + sumi = ggml_cuda_dp4a(grid1, u1, sumi); + } + + const float d1q = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1); + const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); + const float2 ds = __half22float2(bq8_1[iqs].ds); + return d1q * (ds.x*sumi + ds.y*delta); +} + +#define VDR_IQ1_M_Q8_1_MMVQ 1 +#define VDR_IQ1_M_Q8_1_MMQ 1 + +static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_iq1_m * bq1 = (const block_iq1_m *) vbq + kbx; + + const int qs_packed = get_int_b4(bq1->qs, iqs); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + int sumi[2] = {0}; + float sumf[2] = {0.0f}; +#pragma unroll + for (int l0 = 0; l0 < 8; l0 += 2) { + const int qhl = bq1->qh[2*iqs + l0/4] >> (4 * ((l0/2) % 2)); + + const int grid = iq1s_grid_gpu[qs[l0/2] | ((qhl & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + + const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0); + const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1); + + sumi[l0/4] = ggml_cuda_dp4a(grid0, u0, sumi[l0/4]); + sumi[l0/4] = ggml_cuda_dp4a(grid1, u1, sumi[l0/4]); + + const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f*IQ1M_DELTA/0x08); + int sumy = 0; + sumy = ggml_cuda_dp4a(u0, 0x01010101, sumy); + sumy = ggml_cuda_dp4a(u1, 0x01010101, sumy); + sumf[l0/4] += delta*sumy; + } + + const uint16_t * sc = (const uint16_t *) bq1->scales; + + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000); + const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds); + + const int tmp = sc[iqs/2] >> (6*(iqs%2)); + const int sc0 = 2*((tmp >> 0) & 0x07) + 1; + const int sc1 = 2*((tmp >> 3) & 0x07) + 1; + return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1); +} + +#define VDR_IQ4_NL_Q8_1_MMVQ 2 +#define VDR_IQ4_NL_Q8_1_MMQ 4 + +static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_iq4_nl * bq4 = (const block_iq4_nl *) vbq + kbx; + + const int * q8 = (const int *) bq8_1->qs + iqs; + + int sumi = 0; +#pragma unroll + for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) { + const int aux_q4 = get_int_b2(bq4->qs, iqs + l); + const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); + + sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi); + sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi); + } + + const float d = __half2float(bq4->d) * __low2float(bq8_1->ds); + return d * sumi; +} + +#define VDR_IQ4_XS_Q8_1_MMVQ 4 +#define VDR_IQ4_XS_Q8_1_MMQ 4 + +static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq + kbx; + + int sumi = 0; +#pragma unroll + for (int j = 0; j < 4; ++j) { + const int aux_q4 = get_int_b4(bq4->qs, iqs + j); + const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl); + + const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0); + const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4); + + sumi = ggml_cuda_dp4a(v.x, u0, sumi); + sumi = ggml_cuda_dp4a(v.y, u1, sumi); + } + + const int ls = ((bq4->scales_l[iqs/8] >> (iqs & 0x04)) & 0x0F) | (((bq4->scales_h >> (iqs/2)) & 0x03) << 4); + sumi *= ls - 32; + + const float d = __half2float(bq4->d) * __low2float(bq8_1[iqs/4].ds); + return d * sumi; +} diff --git a/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h b/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h new file mode 100644 index 0000000..ba032cf --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h @@ -0,0 +1,23 @@ +#pragma once + +#include <cuda_runtime.h> +#include <cuda.h> +#include <cublas_v2.h> +#include <cuda_bf16.h> +#include <cuda_fp16.h> + +#if CUDART_VERSION >= 12050 +#include <cuda_fp8.h> +#endif // CUDART_VERSION >= 12050 + +#if CUDART_VERSION >= 12080 +#include <cuda_fp4.h> +#endif // CUDART_VERSION >= 12080 + +#if CUDART_VERSION < 11020 +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED +#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH +#define CUBLAS_COMPUTE_16F CUDA_R_16F +#define CUBLAS_COMPUTE_32F CUDA_R_32F +#define cublasComputeType_t cudaDataType_t +#endif // CUDART_VERSION < 11020 diff --git a/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h b/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h new file mode 100644 index 0000000..5cc1b54 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h @@ -0,0 +1,278 @@ +#pragma once + +#define HIP_DISABLE_WARP_SYNC_BUILTINS 1 +#include <hip/hip_runtime.h> +#include <hipblas/hipblas.h> +#include <hip/hip_fp16.h> +#include <hip/hip_bf16.h> + +#if defined(GGML_HIP_ROCWMMA_FATTN) +#include <rocwmma/rocwmma-version.hpp> +#endif // defined(GGML_HIP_ROCWMMA_FATTN) + +#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT +#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N HIPBLAS_OP_N +#define CUBLAS_OP_T HIPBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH 0 +#define CUDA_R_16F HIPBLAS_R_16F +#define CUDA_R_16BF HIPBLAS_R_16B +#define CUDA_R_32F HIPBLAS_R_32F +#define CUBLAS_SIDE_RIGHT HIPBLAS_SIDE_RIGHT +#define CUBLAS_FILL_MODE_UPPER HIPBLAS_FILL_MODE_UPPER +#define CUBLAS_DIAG_NON_UNIT HIPBLAS_DIAG_NON_UNIT +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended +#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned +#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite +#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} +#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) +#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width) +#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +#define __all_sync(mask, var) __all(var) +#define __any_sync(mask, var) __any(var) +#define cublasStrsmBatched hipblasStrsmBatched +#define cublasCreate hipblasCreate +#define cublasDestroy hipblasDestroy +#define cublasGemmEx hipblasGemmEx +#define cublasGemmBatchedEx hipblasGemmBatchedEx +#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx +#define cublasHandle_t hipblasHandle_t +#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS +#define cublasSetStream hipblasSetStream +#define cublasSgemm hipblasSgemm +#define cublasStatus_t hipblasStatus_t +#define cublasOperation_t hipblasOperation_t +#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch +#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer +#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess +#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess +#define cudaDeviceGetAttribute hipDeviceGetAttribute +#define cudaDeviceProp hipDeviceProp_t +#define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaError_t hipError_t +#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled +#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled +#define cudaEventCreateWithFlags hipEventCreateWithFlags +#define cudaEventDisableTiming hipEventDisableTiming +#define cudaEventRecord hipEventRecord +#define cudaEventSynchronize hipEventSynchronize +#define cudaEvent_t hipEvent_t +#define cudaEventDestroy hipEventDestroy +#define cudaFree hipFree +#define cudaFreeHost hipHostFree +#define cudaGetDevice hipGetDevice +#define cudaGetDeviceCount hipGetDeviceCount +#define cudaGetDeviceProperties hipGetDeviceProperties +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaHostRegister hipHostRegister +#define cudaHostRegisterPortable hipHostRegisterPortable +#define cudaHostRegisterReadOnly hipHostRegisterReadOnly +#define cudaHostUnregister hipHostUnregister +#define cudaLaunchCooperativeKernel hipLaunchCooperativeKernel +#define cudaLaunchHostFunc hipLaunchHostFunc +#define cudaMalloc hipMalloc +#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) +#define cudaMallocManaged hipMallocManaged +#define cudaMemAdvise hipMemAdvise +#define cudaMemcpy hipMemcpy +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyPeerAsync hipMemcpyPeerAsync +#define cudaMemcpy2DAsync hipMemcpy2DAsync +#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyKind hipMemcpyKind +#define cudaMemset hipMemset +#define cudaMemsetAsync hipMemsetAsync +#define cudaMemGetInfo hipMemGetInfo +#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize +#define cudaSetDevice hipSetDevice +#define cuDeviceGet hipDeviceGet +#define CUdevice hipDevice_t +#define CUdeviceptr hipDeviceptr_t +#define cuMemUnmap hipMemUnmap +#define CUmemAccessDesc hipMemAccessDesc +#define cuMemAddressFree hipMemAddressFree +#define cuMemRelease hipMemRelease +#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t +#define cuMemCreate hipMemCreate +#define cuMemAddressReserve hipMemAddressReserve +#define cuMemMap hipMemMap +#define cuMemSetAccess hipMemSetAccess +#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity +#define CUmemAllocationProp hipMemAllocationProp +#define cuDeviceGetAttribute hipDeviceGetAttribute +#define cudaStreamCreateWithFlags hipStreamCreateWithFlags +#define cudaStreamDestroy hipStreamDestroy +#define cudaStreamFireAndForget hipStreamFireAndForget +#define cudaStreamNonBlocking hipStreamNonBlocking +#define cudaStreamPerThread hipStreamPerThread +#define cudaStreamSynchronize hipStreamSynchronize +#define cudaStreamWaitEvent hipStreamWaitEvent +#define cudaGraphExec_t hipGraphExec_t +#define cudaGraphNode_t hipGraphNode_t +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaKernelNodeParams hipKernelNodeParams +#define cudaGraphExecDestroy hipGraphExecDestroy +#define cudaGraphLaunch hipGraphLaunch +#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure +#define cudaGraphExecUpdateResult hipGraphExecUpdateResult +#define cudaGraphNodeType hipGraphNodeType +#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel +#define cudaGraphInstantiate hipGraphInstantiate +#define cudaStreamEndCapture hipStreamEndCapture +#define cudaGraphDestroy hipGraphDestroy +#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams +#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction +#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams +#define cudaGraphNodeGetType hipGraphNodeGetType +#define cudaGraphGetNodes hipGraphGetNodes +#define cudaGraphExecUpdate hipGraphExecUpdate +#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed +#define cudaStreamBeginCapture hipStreamBeginCapture +#define cudaGraph_t hipGraph_t +#define cudaStream_t hipStream_t +#define cudaSuccess hipSuccess +#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor +#define cudaFuncSetAttribute hipFuncSetAttribute +#define cudaFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize +#define __trap() do { abort(); __builtin_unreachable(); } while(0) +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED +#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED +#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE +#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH +#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR +#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED +#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR +#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED + +#if HIP_VERSION >= 60500000 +#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F +#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F +#define cublasComputeType_t hipblasComputeType_t +#define cudaDataType_t hipDataType +#else +#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F +#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F +#define cublasComputeType_t hipblasDatatype_t +#define cudaDataType_t hipblasDatatype_t +#endif // HIP_VERSION >= 6050000 + +#if !defined(__HIP_PLATFORM_AMD__) +#error "The HIP backend supports only AMD targets" +#endif // !defined(__HIP_PLATFORM_AMD__) + +#define __CUDA_ARCH__ 1300 + +#if defined(__gfx900__) || defined(__gfx906__) +#define GCN5 +#endif // defined(__gfx900__) || defined(__gfx906__) + +#if defined(__gfx803__) +#define GCN4 +#endif // defined(__gfx803__) + +#if defined(GCN5) || defined(GCN4) +#define GCN +#endif // defined(GCN5) || defined(GCN4) + +#if defined(__gfx942__) +#define CDNA3 +#endif // defined(__gfx942__) + +#if defined(__gfx90a__) +#define CDNA2 +#endif // defined(__gfx90a__) + +#if defined(__gfx908__) +#define CDNA1 +#endif // defined(__gfx908__) + +#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1) +#define CDNA // For the entire family +#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1) + +#if defined(__GFX12__) +#define RDNA4 +#endif // defined(__GFX12__) + +#if defined(__GFX11__) +#define RDNA3 +#endif // defined(__GFX11__) + +#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ + defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) +#define RDNA2 +#endif + +#if defined(__gfx1010__) || defined(__gfx1012__) +#define RDNA1 +#endif // defined(__gfx1010__) || defined(__gfx1012__) + +#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1) +#define RDNA // For the entire family +#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1) + +#ifndef __has_builtin + #define __has_builtin(x) 0 +#endif + +typedef __hip_bfloat16 nv_bfloat16; +typedef __hip_bfloat162 nv_bfloat162; + +typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); +typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); +static __device__ __forceinline__ int __vsubss4(const int a, const int b) { + const int8x4_t va = reinterpret_cast<const int8x4_t&>(a); + const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b); +#if __has_builtin(__builtin_elementwise_sub_sat) + const int8x4_t c = __builtin_elementwise_sub_sat(va, vb); + return reinterpret_cast<const int &>(c); +#else + int8x4_t c; + int16_t tmp; +#pragma unroll + for (int i = 0; i < 4; i++) { + tmp = va[i] - vb[i]; + if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max(); + if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min(); + c[i] = tmp; + } + return reinterpret_cast<int &>(c); +#endif // __has_builtin(__builtin_elementwise_sub_sat) +} + +static __device__ __forceinline__ int __vsub4(const int a, const int b) { + return __vsubss4(a, b); +} + +static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) { + const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a); + const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b); + unsigned int c; + uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c); +#pragma unroll + for (int i = 0; i < 4; ++i) { + vc[i] = va[i] == vb[i] ? 0xff : 0x00; + } + return c; +} + +static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) { + const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a); + const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b); + unsigned int c; + uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c); +#pragma unroll + for (int i = 0; i < 4; ++i) { + vc[i] = va[i] == vb[i] ? 0x00 : 0xff; + } + return c; +} diff --git a/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h b/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h new file mode 100644 index 0000000..1abb8ac --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h @@ -0,0 +1,147 @@ +#pragma once + +#include <musa_runtime.h> +#include <musa.h> +#include <mublas.h> +#include <musa_bf16.h> +#include <musa_fp16.h> +#define CUBLAS_COMPUTE_16F CUDA_R_16F +#define CUBLAS_COMPUTE_32F CUDA_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F +#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT +#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N MUBLAS_OP_N +#define CUBLAS_OP_T MUBLAS_OP_T +#define CUBLAS_DEFAULT_MATH MUBLAS_DEFAULT_MATH +#define CUBLAS_SIDE_RIGHT MUBLAS_SIDE_RIGHT +#define CUBLAS_FILL_MODE_UPPER MUBLAS_FILL_MODE_UPPER +#define CUBLAS_DIAG_NON_UNIT MUBLAS_DIAG_NON_UNIT +#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH +#define CUDA_R_16F MUSA_R_16F +#define CUDA_R_16BF MUSA_R_16BF +#define CUDA_R_32F MUSA_R_32F +#define cublasStrsmBatched mublasStrsmBatched +#define cublasComputeType_t cudaDataType_t +#define cublasCreate mublasCreate +#define cublasDestroy mublasDestroy +#define cublasGemmEx mublasGemmEx +#define cublasGemmBatchedEx mublasGemmBatchedEx +#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx +#define cublasHandle_t mublasHandle_t +#define cublasSetMathMode mublasSetMathMode +#define cublasSetStream mublasSetStream +#define cublasSgemm mublasSgemm +#define cublasStatus_t mublasStatus_t +#define cublasOperation_t mublasOperation_t +#define cublasGetStatusString mublasGetStatusString +#define cudaDataType_t musaDataType_t +#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer +#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess +#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess +#define cudaDeviceProp musaDeviceProp +#define cudaDeviceSynchronize musaDeviceSynchronize +#define cudaError_t musaError_t +#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled +#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled +#define cudaEventCreateWithFlags musaEventCreateWithFlags +#define cudaEventDisableTiming musaEventDisableTiming +#define cudaEventRecord musaEventRecord +#define cudaEventSynchronize musaEventSynchronize +#define cudaEvent_t musaEvent_t +#define cudaEventDestroy musaEventDestroy +#define cudaFree musaFree +#define cudaFreeHost musaFreeHost +#define cudaGetDevice musaGetDevice +#define cudaGetDeviceCount musaGetDeviceCount +#define cudaGetDeviceProperties musaGetDeviceProperties +#define cudaGetErrorString musaGetErrorString +#define cudaGetLastError musaGetLastError +#define cudaHostRegister musaHostRegister +#define cudaHostRegisterPortable musaHostRegisterPortable +#define cudaHostRegisterReadOnly musaHostRegisterReadOnly +#define cudaHostUnregister musaHostUnregister +#define cudaLaunchCooperativeKernel musaLaunchCooperativeKernel +#define cudaLaunchHostFunc musaLaunchHostFunc +#define cudaMalloc musaMalloc +#define cudaMallocHost musaMallocHost +#define cudaMallocManaged musaMallocManaged +#define cudaMemcpy musaMemcpy +#define cudaMemcpyAsync musaMemcpyAsync +#define cudaMemcpyPeerAsync musaMemcpyPeerAsync +#define cudaMemcpy2DAsync musaMemcpy2DAsync +#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost +#define cudaMemcpyHostToDevice musaMemcpyHostToDevice +#define cudaMemcpyKind musaMemcpyKind +#define cudaMemset musaMemset +#define cudaMemsetAsync musaMemsetAsync +#define cudaMemGetInfo musaMemGetInfo +#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize +#define cudaSetDevice musaSetDevice +#define cudaStreamCreateWithFlags musaStreamCreateWithFlags +#define cudaStreamDestroy musaStreamDestroy +#define cudaStreamFireAndForget musaStreamFireAndForget +#define cudaStreamNonBlocking musaStreamNonBlocking +#define cudaStreamPerThread musaStreamPerThread +#define cudaStreamSynchronize musaStreamSynchronize +#define cudaStreamWaitEvent musaStreamWaitEvent +#define cudaStream_t musaStream_t +#define cudaSuccess musaSuccess + +// Additional mappings for MUSA virtual memory pool +#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED +#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE +#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED +#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED +#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE +#define CUdevice MUdevice +#define CUdeviceptr MUdeviceptr +#define CUmemAccessDesc MUmemAccessDesc +#define CUmemAllocationProp MUmemAllocationProp +#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle +#define cuDeviceGet muDeviceGet +#define cuDeviceGetAttribute muDeviceGetAttribute +#define cuMemAddressFree muMemAddressFree +#define cuMemAddressReserve muMemAddressReserve +#define cuMemCreate muMemCreate +#define cuMemGetAllocationGranularity muMemGetAllocationGranularity +#define cuMemMap muMemMap +#define cuMemRelease muMemRelease +#define cuMemSetAccess muMemSetAccess +#define cuMemUnmap muMemUnmap +#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize +#define cudaFuncSetAttribute musaFuncSetAttribute +#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms +#define make_cudaExtent make_musaExtent +#define make_cudaPitchedPtr make_musaPitchedPtr + +// Additional mappings for MUSA graphs +#define CUDA_SUCCESS MUSA_SUCCESS +#define CUresult MUresult +#define cuGetErrorString muGetErrorString +#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure +#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction +#define cudaGraphDestroy musaGraphDestroy +#define cudaGraphExecDestroy musaGraphExecDestroy +#define cudaGraphExec_t musaGraphExec_t +#define cudaGraphExecUpdate musaGraphExecUpdate +#define cudaGraphExecUpdateResult musaGraphExecUpdateResult +#define cudaGraphGetNodes musaGraphGetNodes +#define cudaGraphInstantiate musaGraphInstantiate +#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams +#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams +#define cudaGraphLaunch musaGraphLaunch +#define cudaGraphNodeGetType musaGraphNodeGetType +#define cudaGraphNode_t musaGraphNode_t +#define cudaGraphNodeType musaGraphNodeType +#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel +#define cudaGraph_t musaGraph_t +#define cudaKernelNodeParams musaKernelNodeParams +#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed +#define cudaStreamBeginCapture musaStreamBeginCapture +#define cudaStreamEndCapture musaStreamEndCapture +#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor + +typedef __mt_bfloat16 nv_bfloat16; +typedef __mt_bfloat162 nv_bfloat162; diff --git a/llama.cpp/ggml/src/ggml-cuda/wkv.cu b/llama.cpp/ggml/src/ggml-cuda/wkv.cu new file mode 100644 index 0000000..d2fced7 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/wkv.cu @@ -0,0 +1,199 @@ +#include "common.cuh" +#include "wkv.cuh" + +template <int block_size> +static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + + const int head_size = block_size; + const int batch_i = bid / H; + const int head_i = bid % H; + const int state_size = C * head_size; + const int n_seq_tokens = T / B; + + float state[head_size]; + __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size]; + + #pragma unroll + for (int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid]; + } + + __syncthreads(); + _tf[tid] = tf[head_i * head_size + tid]; + __syncthreads(); + + for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) { + __syncthreads(); + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + __syncthreads(); + + const float _v = v[t]; + float y = 0; + for (int j = 0; j < head_size; j += 4) { + const float4& k = (float4&)(_k[j]); + const float4& r = (float4&)(_r[j]); + const float4& tf = (float4&)(_tf[j]); + const float4& td = (float4&)(_td[j]); + float4& s = (float4&)(state[j]); + float4 kv; + + kv.x = k.x * _v; + kv.y = k.y * _v; + kv.z = k.z * _v; + kv.w = k.w * _v; + + y += r.x * (tf.x * kv.x + s.x); + y += r.y * (tf.y * kv.y + s.y); + y += r.z * (tf.z * kv.z + s.z); + y += r.w * (tf.w * kv.w + s.w); + + s.x = s.x * td.x + kv.x; + s.y = s.y * td.y + kv.y; + s.z = s.z * td.z + kv.z; + s.w = s.w * td.w + kv.w; + } + dst[t] = y; + } + + #pragma unroll + for (int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i]; + } +} + +template <int block_size> +static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + + const int head_size = block_size; + const int batch_i = bid / H; + const int head_i = bid % H; + const int state_size = C * head_size; + const int n_seq_tokens = T / B; + + float state[head_size]; + __shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size]; + +#ifndef GGML_USE_MUSA + #pragma unroll +#endif + for (int i = 0; i < head_size; i++) { + state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i]; + } + + for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) { + __syncthreads(); + _r[tid] = r[t]; + _w[tid] = w[t]; + _k[tid] = k[t]; + _a[tid] = a[t]; + _b[tid] = b[t]; + __syncthreads(); + + float sa = 0; + #pragma unroll + for (int j = 0; j < head_size; j += 4) + { + const float4& a = (float4&)(_a[j]); + const float4& s = (float4&)(state[j]); + sa += a.x * s.x; + sa += a.y * s.y; + sa += a.z * s.z; + sa += a.w * s.w; + } + + const float _v = v[t]; + float y = 0; + for (int j = 0; j < head_size; j += 4) { + const float4& r = (float4&)(_r[j]); + const float4& w = (float4&)(_w[j]); + const float4& k = (float4&)(_k[j]); + const float4& b = (float4&)(_b[j]); + float4& s = (float4&)(state[j]); + float4 kv; + + kv.x = k.x * _v; + kv.y = k.y * _v; + kv.z = k.z * _v; + kv.w = k.w * _v; + + s.x = s.x * w.x + kv.x + sa * b.x; + s.y = s.y * w.y + kv.y + sa * b.y; + s.z = s.z * w.z + kv.z + sa * b.z; + s.w = s.w * w.w + kv.w + sa * b.w; + + y += s.x * r.x; + y += s.y * r.y; + y += s.z * r.z; + y += s.w * r.w; + } + dst[t] = y; + } + + #pragma unroll + for (int i = 0; i < head_size; i++) { + dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i]; + } +} + +void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const float * k_d = (const float *)dst->src[0]->data; + const float * v_d = (const float *)dst->src[1]->data; + const float * r_d = (const float *)dst->src[2]->data; + const float * tf_d = (const float *)dst->src[3]->data; + const float * td_d = (const float *)dst->src[4]->data; + const float * s_d = (const float *)dst->src[5]->data; + + const int64_t B = dst->src[5]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + float * dst_d = (float *)dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2); + + if (C / H == CUDA_WKV_BLOCK_SIZE) { + rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d); + } else { + rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d); + } +} + +void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const float * r_d = (const float *)dst->src[0]->data; + const float * w_d = (const float *)dst->src[1]->data; + const float * k_d = (const float *)dst->src[2]->data; + const float * v_d = (const float *)dst->src[3]->data; + const float * a_d = (const float *)dst->src[4]->data; + const float * b_d = (const float *)dst->src[5]->data; + const float * s_d = (const float *)dst->src[6]->data; + + const int64_t B = dst->src[6]->ne[1]; + const int64_t T = dst->src[0]->ne[2]; + const int64_t C = dst->ne[0]; + const int64_t H = dst->src[0]->ne[1]; + + float * dst_d = (float *)dst->data; + + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32); + GGML_ASSERT(C % H == 0); + GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2); + + if (C / H == CUDA_WKV_BLOCK_SIZE) { + rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d); + } else { + rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d); + } +} diff --git a/llama.cpp/ggml/src/ggml-cuda/wkv.cuh b/llama.cpp/ggml/src/ggml-cuda/wkv.cuh new file mode 100644 index 0000000..9623dd7 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-cuda/wkv.cuh @@ -0,0 +1,7 @@ +#include "common.cuh" + +#define CUDA_WKV_BLOCK_SIZE 64 + +void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst); |
