summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-cuda/vendors
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-cuda/vendors
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-cuda/vendors')
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h23
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/vendors/hip.h278
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/vendors/musa.h147
3 files changed, 448 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h b/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h
new file mode 100644
index 0000000..ba032cf
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h
@@ -0,0 +1,23 @@
1#pragma once
2
3#include <cuda_runtime.h>
4#include <cuda.h>
5#include <cublas_v2.h>
6#include <cuda_bf16.h>
7#include <cuda_fp16.h>
8
9#if CUDART_VERSION >= 12050
10#include <cuda_fp8.h>
11#endif // CUDART_VERSION >= 12050
12
13#if CUDART_VERSION >= 12080
14#include <cuda_fp4.h>
15#endif // CUDART_VERSION >= 12080
16
// Compatibility aliases, active only when CUDART_VERSION < 11020:
// identifiers used by the rest of the backend are mapped onto the names listed on the right-hand side.
17#if CUDART_VERSION < 11020
18#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
// The TF32 math mode name falls back to the generic tensor-op math mode.
19#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
// cublasComputeType_t and its CUBLAS_COMPUTE_* values are expressed through cudaDataType_t / CUDA_R_*.
20#define CUBLAS_COMPUTE_16F CUDA_R_16F
21#define CUBLAS_COMPUTE_32F CUDA_R_32F
22#define cublasComputeType_t cudaDataType_t
23#endif // CUDART_VERSION < 11020
diff --git a/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h b/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h
new file mode 100644
index 0000000..5cc1b54
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h
@@ -0,0 +1,278 @@
1#pragma once
2
3#define HIP_DISABLE_WARP_SYNC_BUILTINS 1
4#include <hip/hip_runtime.h>
5#include <hipblas/hipblas.h>
6#include <hip/hip_fp16.h>
7#include <hip/hip_bf16.h>
8
9#if defined(GGML_HIP_ROCWMMA_FATTN)
10#include <rocwmma/rocwmma-version.hpp>
11#endif // defined(GGML_HIP_ROCWMMA_FATTN)
12
// cuBLAS constant names mapped onto hipBLAS constants.
13#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
// The tensor-op algorithm selector maps to the same default hipBLAS algorithm as CUBLAS_GEMM_DEFAULT.
14#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
15#define CUBLAS_OP_N HIPBLAS_OP_N
16#define CUBLAS_OP_T HIPBLAS_OP_T
17#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
// Defined as the literal 0; cublasSetMathMode (further down in this header) ignores its mode argument.
18#define CUBLAS_TF32_TENSOR_OP_MATH 0
// CUDA data-type enumerators mapped onto hipBLAS data-type enumerators.
19#define CUDA_R_16F HIPBLAS_R_16F
20#define CUDA_R_16BF HIPBLAS_R_16B
21#define CUDA_R_32F HIPBLAS_R_32F
22#define CUBLAS_SIDE_RIGHT HIPBLAS_SIDE_RIGHT
23#define CUBLAS_FILL_MODE_UPPER HIPBLAS_FILL_MODE_UPPER
24#define CUBLAS_DIAG_NON_UNIT HIPBLAS_DIAG_NON_UNIT
// Driver-API (CU_*) virtual-memory-management constants mapped onto HIP enumerators.
25#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
26#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
27#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
28#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
29#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
// Evaluates `fn` once (expected to yield a hipError_t) and aborts through GGML_ABORT with the HIP error
// string when the result is not hipSuccess.
// The body is wrapped in do { } while (0) so that `CU_CHECK(x);` is exactly one statement and stays
// correct inside an unbraced if/else. The argument is parenthesized, and the local is named `err_`
// so it cannot capture a variable called `err` used inside `fn`.
#define CU_CHECK(fn) do { hipError_t err_ = (fn); if (err_ != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err_)); } } while (0)
// CUDA *_sync warp intrinsics expressed through the mask-less HIP builtins.
// The `mask` argument is accepted for source compatibility and then discarded by the expansion.
31#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
32#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
33#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
34#define __all_sync(mask, var) __all(var)
35#define __any_sync(mask, var) __any(var)
// cuBLAS functions and types mapped onto hipBLAS.
36#define cublasStrsmBatched hipblasStrsmBatched
37#define cublasCreate hipblasCreate
38#define cublasDestroy hipblasDestroy
39#define cublasGemmEx hipblasGemmEx
40#define cublasGemmBatchedEx hipblasGemmBatchedEx
41#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
42#define cublasHandle_t hipblasHandle_t
// No hipBLAS call is made here: the macro drops both arguments and evaluates to the success status.
43#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
44#define cublasSetStream hipblasSetStream
45#define cublasSgemm hipblasSgemm
46#define cublasStatus_t hipblasStatus_t
47#define cublasOperation_t hipblasOperation_t
// CUDA runtime API (device, error, event, memory) mapped onto the HIP runtime API.
48#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch
49#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
50#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
51#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
52#define cudaDeviceGetAttribute hipDeviceGetAttribute
53#define cudaDeviceProp hipDeviceProp_t
54#define cudaDeviceSynchronize hipDeviceSynchronize
55#define cudaError_t hipError_t
56#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
57#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
58#define cudaEventCreateWithFlags hipEventCreateWithFlags
59#define cudaEventDisableTiming hipEventDisableTiming
60#define cudaEventRecord hipEventRecord
61#define cudaEventSynchronize hipEventSynchronize
62#define cudaEvent_t hipEvent_t
63#define cudaEventDestroy hipEventDestroy
64#define cudaFree hipFree
65#define cudaFreeHost hipHostFree
66#define cudaGetDevice hipGetDevice
67#define cudaGetDeviceCount hipGetDeviceCount
68#define cudaGetDeviceProperties hipGetDeviceProperties
69#define cudaGetErrorString hipGetErrorString
70#define cudaGetLastError hipGetLastError
71#define cudaHostRegister hipHostRegister
72#define cudaHostRegisterPortable hipHostRegisterPortable
73#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
74#define cudaHostUnregister hipHostUnregister
75#define cudaLaunchCooperativeKernel hipLaunchCooperativeKernel
76#define cudaLaunchHostFunc hipLaunchHostFunc
77#define cudaMalloc hipMalloc
// Function-like macro: forwards to hipHostMalloc and always passes hipHostMallocDefault as the flags argument.
78#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
79#define cudaMallocManaged hipMallocManaged
80#define cudaMemAdvise hipMemAdvise
81#define cudaMemcpy hipMemcpy
82#define cudaMemcpyAsync hipMemcpyAsync
83#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
84#define cudaMemcpy2DAsync hipMemcpy2DAsync
85#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
86#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
87#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
88#define cudaMemcpyKind hipMemcpyKind
89#define cudaMemset hipMemset
90#define cudaMemsetAsync hipMemsetAsync
91#define cudaMemGetInfo hipMemGetInfo
92#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
93#define cudaSetDevice hipSetDevice
// Driver-API (cu*/CU*) device and virtual-memory-management names mapped onto HIP.
94#define cuDeviceGet hipDeviceGet
95#define CUdevice hipDevice_t
96#define CUdeviceptr hipDeviceptr_t
97#define cuMemUnmap hipMemUnmap
98#define CUmemAccessDesc hipMemAccessDesc
99#define cuMemAddressFree hipMemAddressFree
100#define cuMemRelease hipMemRelease
101#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
102#define cuMemCreate hipMemCreate
103#define cuMemAddressReserve hipMemAddressReserve
104#define cuMemMap hipMemMap
105#define cuMemSetAccess hipMemSetAccess
106#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
107#define CUmemAllocationProp hipMemAllocationProp
// Same HIP function as cudaDeviceGetAttribute above.
108#define cuDeviceGetAttribute hipDeviceGetAttribute
// Stream API.
109#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
110#define cudaStreamDestroy hipStreamDestroy
111#define cudaStreamFireAndForget hipStreamFireAndForget
112#define cudaStreamNonBlocking hipStreamNonBlocking
113#define cudaStreamPerThread hipStreamPerThread
114#define cudaStreamSynchronize hipStreamSynchronize
115#define cudaStreamWaitEvent hipStreamWaitEvent
// Graph API handle types.
116#define cudaGraphExec_t hipGraphExec_t
117#define cudaGraphNode_t hipGraphNode_t
// Graph kernel-node parameter type mapped onto its HIP counterpart (defined exactly once).
#define cudaKernelNodeParams hipKernelNodeParams
// CUDA graph / stream-capture API mapped onto the HIP graph API.
120#define cudaGraphExecDestroy hipGraphExecDestroy
121#define cudaGraphLaunch hipGraphLaunch
122#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
123#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
124#define cudaGraphNodeType hipGraphNodeType
125#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
126#define cudaGraphInstantiate hipGraphInstantiate
127#define cudaStreamEndCapture hipStreamEndCapture
128#define cudaGraphDestroy hipGraphDestroy
129#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
130#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
131#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
132#define cudaGraphNodeGetType hipGraphNodeGetType
133#define cudaGraphGetNodes hipGraphGetNodes
134#define cudaGraphExecUpdate hipGraphExecUpdate
135#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
136#define cudaStreamBeginCapture hipStreamBeginCapture
137#define cudaGraph_t hipGraph_t
138#define cudaStream_t hipStream_t
139#define cudaSuccess hipSuccess
// Occupancy query and kernel function attributes.
140#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
141#define cudaFuncSetAttribute hipFuncSetAttribute
142#define cudaFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize
// __trap() becomes abort() followed by __builtin_unreachable(), wrapped so it is usable as one statement.
143#define __trap() do { abort(); __builtin_unreachable(); } while(0)
// cuBLAS status codes mapped onto hipBLAS status codes.
// NOTE(review): CUBLAS_STATUS_SUCCESS repeats an identical definition from earlier in this header.
144#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
145#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
146#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
147#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
148#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
149#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
150#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
151#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
152#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
153
// cuBLAS compute-type names, selected by HIP_VERSION.
// From 60500000 on, the dedicated hipblasComputeType_t / hipDataType names are used.
154#if HIP_VERSION >= 60500000
155#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
156#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
157#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
158#define cublasComputeType_t hipblasComputeType_t
159#define cudaDataType_t hipDataType
160#else
// Older HIP: compute types are expressed as hipblasDatatype_t values.
// CUBLAS_COMPUTE_32F_FAST_16F maps to plain HIPBLAS_R_32F in this branch.
161#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
162#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
163#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
164#define cublasComputeType_t hipblasDatatype_t
165#define cudaDataType_t hipblasDatatype_t
166#endif // HIP_VERSION >= 60500000
167
// Compilation is rejected unless __HIP_PLATFORM_AMD__ is defined.
168#if !defined(__HIP_PLATFORM_AMD__)
169#error "The HIP backend supports only AMD targets"
170#endif // !defined(__HIP_PLATFORM_AMD__)
171
// Fixed __CUDA_ARCH__ value for HIP builds.
// NOTE(review): presumably chosen so that code gated on __CUDA_ARCH__ thresholds is enabled — confirm against the users of this macro.
172#define __CUDA_ARCH__ 1300
173
// GPU family macros derived from the compiler-provided per-target __gfx*__ / __GFX*__ macros.
174#if defined(__gfx900__) || defined(__gfx906__)
175#define GCN5
176#endif // defined(__gfx900__) || defined(__gfx906__)
177
178#if defined(__gfx803__)
179#define GCN4
180#endif // defined(__gfx803__)
181
// GCN is defined when either GCN generation above was detected.
182#if defined(GCN5) || defined(GCN4)
183#define GCN
184#endif // defined(GCN5) || defined(GCN4)
185
186#if defined(__gfx942__)
187#define CDNA3
188#endif // defined(__gfx942__)
189
190#if defined(__gfx90a__)
191#define CDNA2
192#endif // defined(__gfx90a__)
193
194#if defined(__gfx908__)
195#define CDNA1
196#endif // defined(__gfx908__)
197
198#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
199#define CDNA // For the entire family
200#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
201
202#if defined(__GFX12__)
203#define RDNA4
204#endif // defined(__GFX12__)
205
206#if defined(__GFX11__)
207#define RDNA3
208#endif // defined(__GFX11__)
209
210#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
211 defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
212#define RDNA2
213#endif // RDNA2 targets (__gfx1030__ through __gfx1037__)
214
215#if defined(__gfx1010__) || defined(__gfx1012__)
216#define RDNA1
217#endif // defined(__gfx1010__) || defined(__gfx1012__)
218
219#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
220#define RDNA // For the entire family
221#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
222
// Fallback so that `#if __has_builtin(...)` evaluates to 0 on compilers that do not provide __has_builtin.
223#ifndef __has_builtin
224 #define __has_builtin(x) 0
225#endif
226
// CUDA bfloat16 type names provided as aliases of the HIP bfloat16 types.
227typedef __hip_bfloat16 nv_bfloat16;
228typedef __hip_bfloat162 nv_bfloat162;
229
// 4-lane vectors of signed / unsigned bytes (ext_vector_type); the per-byte helpers below
// reinterpret a 32-bit integer as one of these to operate on its four bytes.
230typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
231typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
232static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
233 const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
234 const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
235#if __has_builtin(__builtin_elementwise_sub_sat)
236 const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
237 return reinterpret_cast<const int &>(c);
238#else
239 int8x4_t c;
240 int16_t tmp;
241#pragma unroll
242 for (int i = 0; i < 4; i++) {
243 tmp = va[i] - vb[i];
244 if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
245 if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
246 c[i] = tmp;
247 }
248 return reinterpret_cast<int &>(c);
249#endif // __has_builtin(__builtin_elementwise_sub_sat)
250}
251
252static __device__ __forceinline__ int __vsub4(const int a, const int b) {
253 return __vsubss4(a, b);
254}
255
256static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
257 const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
258 const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
259 unsigned int c;
260 uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
261#pragma unroll
262 for (int i = 0; i < 4; ++i) {
263 vc[i] = va[i] == vb[i] ? 0xff : 0x00;
264 }
265 return c;
266}
267
268static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
269 const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
270 const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
271 unsigned int c;
272 uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
273#pragma unroll
274 for (int i = 0; i < 4; ++i) {
275 vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
276 }
277 return c;
278}
diff --git a/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h b/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h
new file mode 100644
index 0000000..1abb8ac
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h
@@ -0,0 +1,147 @@
1#pragma once
2
3#include <musa_runtime.h>
4#include <musa.h>
5#include <mublas.h>
6#include <musa_bf16.h>
7#include <musa_fp16.h>
// cuBLAS constants mapped onto muBLAS.
// CUBLAS_COMPUTE_16F / CUBLAS_COMPUTE_32F expand to CUDA_R_* names, which are themselves mapped to MUSA_R_* below.
8#define CUBLAS_COMPUTE_16F CUDA_R_16F
9#define CUBLAS_COMPUTE_32F CUDA_R_32F
10#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
11#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
// The tensor-op algorithm selector maps to the same default muBLAS algorithm as CUBLAS_GEMM_DEFAULT.
12#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
13#define CUBLAS_OP_N MUBLAS_OP_N
14#define CUBLAS_OP_T MUBLAS_OP_T
15#define CUBLAS_DEFAULT_MATH MUBLAS_DEFAULT_MATH
16#define CUBLAS_SIDE_RIGHT MUBLAS_SIDE_RIGHT
17#define CUBLAS_FILL_MODE_UPPER MUBLAS_FILL_MODE_UPPER
18#define CUBLAS_DIAG_NON_UNIT MUBLAS_DIAG_NON_UNIT
19#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
20#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH
21#define CUDA_R_16F MUSA_R_16F
22#define CUDA_R_16BF MUSA_R_16BF
23#define CUDA_R_32F MUSA_R_32F
// cuBLAS functions and types mapped onto muBLAS.
24#define cublasStrsmBatched mublasStrsmBatched
// Expands to cudaDataType_t, which is mapped to musaDataType_t below.
25#define cublasComputeType_t cudaDataType_t
26#define cublasCreate mublasCreate
27#define cublasDestroy mublasDestroy
28#define cublasGemmEx mublasGemmEx
29#define cublasGemmBatchedEx mublasGemmBatchedEx
30#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
31#define cublasHandle_t mublasHandle_t
32#define cublasSetMathMode mublasSetMathMode
33#define cublasSetStream mublasSetStream
34#define cublasSgemm mublasSgemm
35#define cublasStatus_t mublasStatus_t
36#define cublasOperation_t mublasOperation_t
37#define cublasGetStatusString mublasGetStatusString
// CUDA runtime API (device, error, event, memory, stream) mapped onto the MUSA runtime API.
38#define cudaDataType_t musaDataType_t
39#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
40#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
41#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
42#define cudaDeviceProp musaDeviceProp
43#define cudaDeviceSynchronize musaDeviceSynchronize
44#define cudaError_t musaError_t
45#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
46#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
47#define cudaEventCreateWithFlags musaEventCreateWithFlags
48#define cudaEventDisableTiming musaEventDisableTiming
49#define cudaEventRecord musaEventRecord
50#define cudaEventSynchronize musaEventSynchronize
51#define cudaEvent_t musaEvent_t
52#define cudaEventDestroy musaEventDestroy
53#define cudaFree musaFree
54#define cudaFreeHost musaFreeHost
55#define cudaGetDevice musaGetDevice
56#define cudaGetDeviceCount musaGetDeviceCount
57#define cudaGetDeviceProperties musaGetDeviceProperties
58#define cudaGetErrorString musaGetErrorString
59#define cudaGetLastError musaGetLastError
60#define cudaHostRegister musaHostRegister
61#define cudaHostRegisterPortable musaHostRegisterPortable
62#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
63#define cudaHostUnregister musaHostUnregister
64#define cudaLaunchCooperativeKernel musaLaunchCooperativeKernel
65#define cudaLaunchHostFunc musaLaunchHostFunc
66#define cudaMalloc musaMalloc
67#define cudaMallocHost musaMallocHost
68#define cudaMallocManaged musaMallocManaged
69#define cudaMemcpy musaMemcpy
70#define cudaMemcpyAsync musaMemcpyAsync
71#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
72#define cudaMemcpy2DAsync musaMemcpy2DAsync
73#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
74#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
75#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
76#define cudaMemcpyKind musaMemcpyKind
77#define cudaMemset musaMemset
78#define cudaMemsetAsync musaMemsetAsync
79#define cudaMemGetInfo musaMemGetInfo
80#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
81#define cudaSetDevice musaSetDevice
82#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
83#define cudaStreamDestroy musaStreamDestroy
84#define cudaStreamFireAndForget musaStreamFireAndForget
85#define cudaStreamNonBlocking musaStreamNonBlocking
86#define cudaStreamPerThread musaStreamPerThread
87#define cudaStreamSynchronize musaStreamSynchronize
88#define cudaStreamWaitEvent musaStreamWaitEvent
89#define cudaStream_t musaStream_t
90#define cudaSuccess musaSuccess
91
92// Additional mappings for MUSA virtual memory pool
// Driver-API constants (CU_* -> MU_*).
93#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
94#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
95#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
96#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
97#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
// Driver-API types (CU* -> MU*).
98#define CUdevice MUdevice
99#define CUdeviceptr MUdeviceptr
100#define CUmemAccessDesc MUmemAccessDesc
101#define CUmemAllocationProp MUmemAllocationProp
102#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
// Driver-API functions (cu* -> mu*).
103#define cuDeviceGet muDeviceGet
104#define cuDeviceGetAttribute muDeviceGetAttribute
105#define cuMemAddressFree muMemAddressFree
106#define cuMemAddressReserve muMemAddressReserve
107#define cuMemCreate muMemCreate
108#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
109#define cuMemMap muMemMap
110#define cuMemRelease muMemRelease
111#define cuMemSetAccess muMemSetAccess
112#define cuMemUnmap muMemUnmap
// Runtime-API names for kernel function attributes and 3D peer copies.
113#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
114#define cudaFuncSetAttribute musaFuncSetAttribute
115#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
116#define make_cudaExtent make_musaExtent
117#define make_cudaPitchedPtr make_musaPitchedPtr
118
119// Additional mappings for MUSA graphs
// Driver-API result type, success code and error-string lookup.
120#define CUDA_SUCCESS MUSA_SUCCESS
121#define CUresult MUresult
122#define cuGetErrorString muGetErrorString
// Graph and stream-capture API (cudaGraph* / cudaStream*Capture -> musa*).
123#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
124#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
125#define cudaGraphDestroy musaGraphDestroy
126#define cudaGraphExecDestroy musaGraphExecDestroy
127#define cudaGraphExec_t musaGraphExec_t
128#define cudaGraphExecUpdate musaGraphExecUpdate
129#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
130#define cudaGraphGetNodes musaGraphGetNodes
131#define cudaGraphInstantiate musaGraphInstantiate
132#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
133#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
134#define cudaGraphLaunch musaGraphLaunch
135#define cudaGraphNodeGetType musaGraphNodeGetType
136#define cudaGraphNode_t musaGraphNode_t
137#define cudaGraphNodeType musaGraphNodeType
138#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
139#define cudaGraph_t musaGraph_t
140#define cudaKernelNodeParams musaKernelNodeParams
141#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
142#define cudaStreamBeginCapture musaStreamBeginCapture
143#define cudaStreamEndCapture musaStreamEndCapture
// Occupancy query.
144#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
145
// CUDA bfloat16 type names provided as aliases of the MUSA (__mt_*) bfloat16 types.
146typedef __mt_bfloat16 nv_bfloat16;
147typedef __mt_bfloat162 nv_bfloat162;