summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-musa/mudnn.cu
blob: 020c1702c45c0ec4bae3caccb49d1d793b5fc460 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include <mutex>
#include <mudnn.h>

#include "mudnn.cuh"

namespace mudnn = musa::dnn;

// Maps a mudnn::Status code to a short, human-readable description.
// The returned pointer refers to a string literal, so the caller must not free it.
// Any status value not listed in the table yields "Unknown mudnn status".
const char* mudnnGetErrorString(mudnn::Status err) {
    static const struct {
        mudnn::Status code;
        const char*   text;
    } k_status_names[] = {
        { mudnn::Status::SUCCESS,           "Success"               },
        { mudnn::Status::INVALID_PARAMETER, "Invalid parameter"     },
        { mudnn::Status::NOT_INITIALIZED,   "Not initialized"       },
        { mudnn::Status::ALLOC_FAILED,      "Allocation failed"     },
        { mudnn::Status::NOT_SUPPORTED,     "Not supported"         },
        { mudnn::Status::INTERNAL_ERROR,    "Internal error"        },
        { mudnn::Status::ARCH_MISMATCH,     "Architecture mismatch" },
        { mudnn::Status::EXECUTION_FAILED,  "Execution failed"      },
    };

    for (const auto& entry : k_status_names) {
        if (entry.code == err) {
            return entry.text;
        }
    }
    return "Unknown mudnn status";
}

// Error checking macro for MUDNN calls.
// Forwards `err` to CUDA_CHECK_GEN (defined outside this file) with
// mudnn::Status::SUCCESS as the expected value and mudnnGetErrorString as the
// function used to describe a failing status.
// NOTE(review): the failure action (log/abort) is whatever CUDA_CHECK_GEN does — confirm there.
#define MUDNN_CHECK(err) CUDA_CHECK_GEN(err, mudnn::Status::SUCCESS, mudnnGetErrorString)

namespace {
    // One lazily created mudnn::Handle per device id. All access to the map
    // goes through get_cached_handle(), which holds handle_cache_mutex.
    std::unordered_map<int, std::unique_ptr<mudnn::Handle>> handle_cache;
    std::mutex handle_cache_mutex;

    // Returns the handle cached for `device_id`, constructing it on first use.
    // The map keeps ownership; the returned raw pointer is non-owning.
    // Only the cache lookup/insert is serialized by the mutex; calls made on
    // the returned handle afterwards are not protected by it.
    mudnn::Handle* get_cached_handle(int device_id) {
        std::lock_guard<std::mutex> guard(handle_cache_mutex);

        // operator[] yields the existing entry or a fresh empty slot.
        std::unique_ptr<mudnn::Handle>& slot = handle_cache[device_id];
        if (!slot) {
            slot = std::make_unique<mudnn::Handle>(device_id);
        }
        return slot.get();
    }
}

// Extracts dimensions and strides from a ggml_tensor
int get_ggml_dims_and_strides(const ggml_tensor* tensor,
                              std::vector<int64_t>& dims,
                              std::vector<int64_t>& strides) {
    const int ndims = ggml_n_dims(tensor);
    const size_t element_size = ggml_element_size(tensor);

    dims.resize(ndims);
    strides.resize(ndims);

    for (int i = 0; i < ndims; ++i) {
        dims[i] = tensor->ne[i];
        strides[i] = tensor->nb[i] / static_cast<int64_t>(element_size);
    }
    return ndims;
}

// Translates a ggml_type into the matching mudnn::Tensor::Type.
// Only GGML_TYPE_F32 (-> FLOAT) and GGML_TYPE_F16 (-> HALF) are handled.
// Every other type is reported through MUDNN_CHECK with
// mudnn::Status::NOT_SUPPORTED; if that check returns, FLOAT is used as a fallback.
mudnn::Tensor::Type ggml_type_to_mudnn_type(ggml_type type) {
    if (type == GGML_TYPE_F32) {
        return mudnn::Tensor::Type::FLOAT;
    }
    if (type == GGML_TYPE_F16) {
        return mudnn::Tensor::Type::HALF;
    }

    // TODO: Add support for other types

    MUDNN_CHECK(mudnn::Status::NOT_SUPPORTED);

    return mudnn::Tensor::Type::FLOAT; // Default fallback
}

// Copies `src` into `dst` by running a mudnn::Unary op in IDENTITY mode on the
// cached handle for ctx.device, after binding that handle to ctx.stream().
//
// Both tensor descriptors are built from `src`'s dims and element strides;
// `dst` only contributes its type and data pointer.
// NOTE(review): this presumably requires dst to share src's layout — confirm
// that callers guarantee it.
//
// Every muDNN call is wrapped in MUDNN_CHECK. Returns musaSuccess when all
// calls pass the check.
musaError_t mudnnMemcpyAsync(ggml_backend_cuda_context& ctx, const ggml_tensor* dst, const ggml_tensor* src) {
    mudnn::Tensor dst_desc, src_desc;

    // Element types come from each tensor individually.
    MUDNN_CHECK(dst_desc.SetType(ggml_type_to_mudnn_type(dst->type)));
    MUDNN_CHECK(src_desc.SetType(ggml_type_to_mudnn_type(src->type)));

    // Shape and strides are taken from the source and applied to both sides.
    std::vector<int64_t> shape, elem_strides;
    const int rank = get_ggml_dims_and_strides(src, shape, elem_strides);

    MUDNN_CHECK(dst_desc.SetNdInfo(rank, shape.data(), elem_strides.data()));
    MUDNN_CHECK(src_desc.SetNdInfo(rank, shape.data(), elem_strides.data()));
    MUDNN_CHECK(dst_desc.SetAddr(dst->data));
    MUDNN_CHECK(src_desc.SetAddr(src->data));

    // Configure the unary op as an identity.
    mudnn::Unary identity_op;
    MUDNN_CHECK(identity_op.SetMode(mudnn::Unary::Mode::IDENTITY));
    MUDNN_CHECK(identity_op.SetAlpha(0.0f));
    MUDNN_CHECK(identity_op.SetBeta(0.0f));

    // Run on the per-device handle, using the context's stream.
    mudnn::Handle* dnn_handle = get_cached_handle(ctx.device);
    MUDNN_CHECK(dnn_handle->SetStream(ctx.stream()));
    MUDNN_CHECK(identity_op.Run(*dnn_handle, dst_desc, src_desc));

    return musaSuccess;
}