#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>

#include <mudnn.h>

#include "mudnn.cuh"

namespace mudnn = musa::dnn;
  7
  8// Returns a human-readable error string for mudnn::Status
  9const char* mudnnGetErrorString(mudnn::Status err) {
 10    switch (err) {
 11        case mudnn::Status::SUCCESS:
 12            return "Success";
 13        case mudnn::Status::INVALID_PARAMETER:
 14            return "Invalid parameter";
 15        case mudnn::Status::NOT_INITIALIZED:
 16            return "Not initialized";
 17        case mudnn::Status::ALLOC_FAILED:
 18            return "Allocation failed";
 19        case mudnn::Status::NOT_SUPPORTED:
 20            return "Not supported";
 21        case mudnn::Status::INTERNAL_ERROR:
 22            return "Internal error";
 23        case mudnn::Status::ARCH_MISMATCH:
 24            return "Architecture mismatch";
 25        case mudnn::Status::EXECUTION_FAILED:
 26            return "Execution failed";
 27        default:
 28            return "Unknown mudnn status";
 29    }
 30}
 31
 // MUDNN_CHECK(status_expr): forwards a mudnn::Status-producing expression to
// CUDA_CHECK_GEN (defined elsewhere) with mudnn::Status::SUCCESS as the expected
// value and mudnnGetErrorString as the formatter for any other status. What
// happens on a mismatch is decided by CUDA_CHECK_GEN.
#define MUDNN_CHECK(status_expr) \
    CUDA_CHECK_GEN(status_expr, mudnn::Status::SUCCESS, mudnnGetErrorString)
 34
 35namespace {
 36    // Thread-safe cache for mudnn::Handle objects per device
 37    std::unordered_map<int, std::unique_ptr<mudnn::Handle>> handle_cache;
 38    std::mutex handle_cache_mutex;
 39
 40    mudnn::Handle* get_cached_handle(int device_id) {
 41        std::lock_guard<std::mutex> lock(handle_cache_mutex);
 42        auto it = handle_cache.find(device_id);
 43        if (it != handle_cache.end()) {
 44            return it->second.get();
 45        }
 46        auto handle = std::make_unique<mudnn::Handle>(device_id);
 47        mudnn::Handle* handle_ptr = handle.get();
 48        handle_cache[device_id] = std::move(handle);
 49        return handle_ptr;
 50    }
 51}
 52
 53// Extracts dimensions and strides from a ggml_tensor
 54int get_ggml_dims_and_strides(const ggml_tensor* tensor,
 55                              std::vector<int64_t>& dims,
 56                              std::vector<int64_t>& strides) {
 57    const int ndims = ggml_n_dims(tensor);
 58    const size_t element_size = ggml_element_size(tensor);
 59
 60    dims.resize(ndims);
 61    strides.resize(ndims);
 62
 63    for (int i = 0; i < ndims; ++i) {
 64        dims[i] = tensor->ne[i];
 65        strides[i] = tensor->nb[i] / static_cast<int64_t>(element_size);
 66    }
 67    return ndims;
 68}
 69
 70// Converts ggml_type to mudnn::Tensor::Type
 71mudnn::Tensor::Type ggml_type_to_mudnn_type(ggml_type type) {
 72    switch (type) {
 73        case GGML_TYPE_F32:
 74            return mudnn::Tensor::Type::FLOAT;
 75        case GGML_TYPE_F16:
 76            return mudnn::Tensor::Type::HALF;
 77
 78        // TODO: Add support for other types
 79
 80        default:
 81            MUDNN_CHECK(mudnn::Status::NOT_SUPPORTED);
 82    }
 83
 84    return mudnn::Tensor::Type::FLOAT; // Default fallback
 85}
 86
 87// Asynchronous memory copy using mudnn::Unary::IDENTITY
 88musaError_t mudnnMemcpyAsync(ggml_backend_cuda_context& ctx, const ggml_tensor* dst, const ggml_tensor* src) {
 89    mudnn::Tensor tensor_dst, tensor_src;
 90
 91    MUDNN_CHECK(tensor_dst.SetType(ggml_type_to_mudnn_type(dst->type)));
 92    MUDNN_CHECK(tensor_src.SetType(ggml_type_to_mudnn_type(src->type)));
 93
 94    std::vector<int64_t> dims, strides;
 95    const int ndims = get_ggml_dims_and_strides(src, dims, strides);
 96
 97    MUDNN_CHECK(tensor_dst.SetNdInfo(ndims, dims.data(), strides.data()));
 98    MUDNN_CHECK(tensor_src.SetNdInfo(ndims, dims.data(), strides.data()));
 99    MUDNN_CHECK(tensor_dst.SetAddr(dst->data));
100    MUDNN_CHECK(tensor_src.SetAddr(src->data));
101
102    mudnn::Unary op;
103    MUDNN_CHECK(op.SetMode(mudnn::Unary::Mode::IDENTITY));
104    MUDNN_CHECK(op.SetAlpha(0.0f));
105    MUDNN_CHECK(op.SetBeta(0.0f));
106
107    mudnn::Handle* handle = get_cached_handle(ctx.device);
108    MUDNN_CHECK(handle->SetStream(ctx.stream()));
109    MUDNN_CHECK(op.Run(*handle, tensor_dst, tensor_src));
110
111    return musaSuccess;
112}