| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-musa/mudnn.cu | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-musa/mudnn.cu')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-musa/mudnn.cu | 112 |
1 file changed, 112 insertions, 0 deletions
```diff
diff --git a/llama.cpp/ggml/src/ggml-musa/mudnn.cu b/llama.cpp/ggml/src/ggml-musa/mudnn.cu
new file mode 100644
index 0000000..020c170
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-musa/mudnn.cu
@@ -0,0 +1,112 @@
+#include <mutex>
+#include <mudnn.h>
+
+#include "mudnn.cuh"
+
+namespace mudnn = musa::dnn;
+
+// Returns a human-readable error string for mudnn::Status
+const char* mudnnGetErrorString(mudnn::Status err) {
+    switch (err) {
+        case mudnn::Status::SUCCESS:
+            return "Success";
+        case mudnn::Status::INVALID_PARAMETER:
+            return "Invalid parameter";
+        case mudnn::Status::NOT_INITIALIZED:
+            return "Not initialized";
+        case mudnn::Status::ALLOC_FAILED:
+            return "Allocation failed";
+        case mudnn::Status::NOT_SUPPORTED:
+            return "Not supported";
+        case mudnn::Status::INTERNAL_ERROR:
+            return "Internal error";
+        case mudnn::Status::ARCH_MISMATCH:
+            return "Architecture mismatch";
+        case mudnn::Status::EXECUTION_FAILED:
+            return "Execution failed";
+        default:
+            return "Unknown mudnn status";
+    }
+}
+
+// Error checking macro for MUDNN calls
+#define MUDNN_CHECK(err) CUDA_CHECK_GEN(err, mudnn::Status::SUCCESS, mudnnGetErrorString)
+
+namespace {
+    // Thread-safe cache for mudnn::Handle objects per device
+    std::unordered_map<int, std::unique_ptr<mudnn::Handle>> handle_cache;
+    std::mutex handle_cache_mutex;
+
+    mudnn::Handle* get_cached_handle(int device_id) {
+        std::lock_guard<std::mutex> lock(handle_cache_mutex);
+        auto it = handle_cache.find(device_id);
+        if (it != handle_cache.end()) {
+            return it->second.get();
+        }
+        auto handle = std::make_unique<mudnn::Handle>(device_id);
+        mudnn::Handle* handle_ptr = handle.get();
+        handle_cache[device_id] = std::move(handle);
+        return handle_ptr;
+    }
+}
+
+// Extracts dimensions and strides from a ggml_tensor
+int get_ggml_dims_and_strides(const ggml_tensor* tensor,
+                              std::vector<int64_t>& dims,
+                              std::vector<int64_t>& strides) {
+    const int ndims = ggml_n_dims(tensor);
+    const size_t element_size = ggml_element_size(tensor);
+
+    dims.resize(ndims);
+    strides.resize(ndims);
+
+    for (int i = 0; i < ndims; ++i) {
+        dims[i]    = tensor->ne[i];
+        strides[i] = tensor->nb[i] / static_cast<int64_t>(element_size);
+    }
+    return ndims;
+}
+
+// Converts ggml_type to mudnn::Tensor::Type
+mudnn::Tensor::Type ggml_type_to_mudnn_type(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return mudnn::Tensor::Type::FLOAT;
+        case GGML_TYPE_F16:
+            return mudnn::Tensor::Type::HALF;
+
+        // TODO: Add support for other types
+
+        default:
+            MUDNN_CHECK(mudnn::Status::NOT_SUPPORTED);
+    }
+
+    return mudnn::Tensor::Type::FLOAT; // Default fallback
+}
+
+// Asynchronous memory copy using mudnn::Unary::IDENTITY
+musaError_t mudnnMemcpyAsync(ggml_backend_cuda_context& ctx, const ggml_tensor* dst, const ggml_tensor* src) {
+    mudnn::Tensor tensor_dst, tensor_src;
+
+    MUDNN_CHECK(tensor_dst.SetType(ggml_type_to_mudnn_type(dst->type)));
+    MUDNN_CHECK(tensor_src.SetType(ggml_type_to_mudnn_type(src->type)));
+
+    std::vector<int64_t> dims, strides;
+    const int ndims = get_ggml_dims_and_strides(src, dims, strides);
+
+    MUDNN_CHECK(tensor_dst.SetNdInfo(ndims, dims.data(), strides.data()));
+    MUDNN_CHECK(tensor_src.SetNdInfo(ndims, dims.data(), strides.data()));
+    MUDNN_CHECK(tensor_dst.SetAddr(dst->data));
+    MUDNN_CHECK(tensor_src.SetAddr(src->data));
+
+    mudnn::Unary op;
+    MUDNN_CHECK(op.SetMode(mudnn::Unary::Mode::IDENTITY));
+    MUDNN_CHECK(op.SetAlpha(0.0f));
+    MUDNN_CHECK(op.SetBeta(0.0f));
+
+    mudnn::Handle* handle = get_cached_handle(ctx.device);
+    MUDNN_CHECK(handle->SetStream(ctx.stream()));
+    MUDNN_CHECK(op.Run(*handle, tensor_dst, tensor_src));
+
+    return musaSuccess;
+}
```
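For context, the entry point this file adds is `mudnnMemcpyAsync`: it copies a tensor by running muDNN's `Unary::Mode::IDENTITY` op over tensors whose layouts were registered via `SetNdInfo`, so unlike a flat `musaMemcpyAsync` byte copy it can honor strided (non-contiguous) layouts. Below is a minimal sketch of a caller, assuming ggml-musa's usual aliasing of the CUDA runtime onto MUSA (so `CUDA_CHECK` accepts a `musaError_t`); the wrapper name `ggml_musa_cpy_tensor` and its fallback logic are illustrative only, not part of this commit:

```cpp
#include "mudnn.cuh"  // declares mudnnMemcpyAsync (added by this commit)

// Illustrative wrapper, not part of this commit: route F32/F16 copies
// through muDNN's stride-aware identity op, and fall back to a flat
// byte copy for other types (valid only for contiguous tensors of
// equal byte size).
static void ggml_musa_cpy_tensor(ggml_backend_cuda_context & ctx,
                                 ggml_tensor * dst, const ggml_tensor * src) {
    if (src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16) {
        // Stride-aware copy, enqueued on the context's stream
        CUDA_CHECK(mudnnMemcpyAsync(ctx, dst, src));
    } else {
        // Flat device-to-device byte copy; assumes contiguous layouts
        CUDA_CHECK(musaMemcpyAsync(dst->data, src->data, ggml_nbytes(src),
                                   musaMemcpyDeviceToDevice, ctx.stream()));
    }
}
```

Two details worth noting from the diff itself: `mudnnMemcpyAsync` describes both tensors with the *source's* dims and strides, so it assumes `dst` shares `src`'s shape; and because `Handle` objects are cached per device, the only per-call handle work is `SetStream`.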
