1#pragma once
  2
  3#include <musa_runtime.h>
  4#include <musa.h>
  5#include <mublas.h>
  6#include <musa_bf16.h>
  7#include <musa_fp16.h>
  8#define CUBLAS_COMPUTE_16F CUDA_R_16F
  9#define CUBLAS_COMPUTE_32F CUDA_R_32F
 10#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
 11#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
 12#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
 13#define CUBLAS_OP_N MUBLAS_OP_N
 14#define CUBLAS_OP_T MUBLAS_OP_T
 15#define CUBLAS_DEFAULT_MATH MUBLAS_DEFAULT_MATH
 16#define CUBLAS_SIDE_RIGHT MUBLAS_SIDE_RIGHT
 17#define CUBLAS_FILL_MODE_UPPER MUBLAS_FILL_MODE_UPPER
 18#define CUBLAS_DIAG_NON_UNIT MUBLAS_DIAG_NON_UNIT
 19#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
 20#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH
 21#define CUDA_R_16F  MUSA_R_16F
 22#define CUDA_R_16BF MUSA_R_16BF
 23#define CUDA_R_32F  MUSA_R_32F
 24#define cublasStrsmBatched mublasStrsmBatched
 25#define cublasComputeType_t cudaDataType_t
 26#define cublasCreate mublasCreate
 27#define cublasDestroy mublasDestroy
 28#define cublasGemmEx mublasGemmEx
 29#define cublasGemmBatchedEx mublasGemmBatchedEx
 30#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
 31#define cublasHandle_t mublasHandle_t
 32#define cublasSetMathMode mublasSetMathMode
 33#define cublasSetStream mublasSetStream
 34#define cublasSgemm mublasSgemm
 35#define cublasStatus_t mublasStatus_t
 36#define cublasOperation_t mublasOperation_t
 37#define cublasGetStatusString mublasGetStatusString
 38#define cudaDataType_t musaDataType_t
 39#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
 40#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
 41#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
 42#define cudaDeviceProp musaDeviceProp
 43#define cudaDeviceSynchronize musaDeviceSynchronize
 44#define cudaError_t musaError_t
 45#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
 46#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
 47#define cudaEventCreateWithFlags musaEventCreateWithFlags
 48#define cudaEventDisableTiming musaEventDisableTiming
 49#define cudaEventRecord musaEventRecord
 50#define cudaEventSynchronize musaEventSynchronize
 51#define cudaEvent_t musaEvent_t
 52#define cudaEventDestroy musaEventDestroy
 53#define cudaFree musaFree
 54#define cudaFreeHost musaFreeHost
 55#define cudaGetDevice musaGetDevice
 56#define cudaGetDeviceCount musaGetDeviceCount
 57#define cudaGetDeviceProperties musaGetDeviceProperties
 58#define cudaGetErrorString musaGetErrorString
 59#define cudaGetLastError musaGetLastError
 60#define cudaHostRegister musaHostRegister
 61#define cudaHostRegisterPortable musaHostRegisterPortable
 62#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
 63#define cudaHostUnregister musaHostUnregister
 64#define cudaLaunchCooperativeKernel musaLaunchCooperativeKernel
 65#define cudaLaunchHostFunc musaLaunchHostFunc
 66#define cudaMalloc musaMalloc
 67#define cudaMallocHost musaMallocHost
 68#define cudaMallocManaged musaMallocManaged
 69#define cudaMemcpy musaMemcpy
 70#define cudaMemcpyAsync musaMemcpyAsync
 71#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
 72#define cudaMemcpy2DAsync musaMemcpy2DAsync
 73#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
 74#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
 75#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
 76#define cudaMemcpyKind musaMemcpyKind
 77#define cudaMemset musaMemset
 78#define cudaMemsetAsync musaMemsetAsync
 79#define cudaMemGetInfo musaMemGetInfo
 80#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
 81#define cudaSetDevice musaSetDevice
 82#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
 83#define cudaStreamDestroy musaStreamDestroy
 84#define cudaStreamFireAndForget musaStreamFireAndForget
 85#define cudaStreamNonBlocking musaStreamNonBlocking
 86#define cudaStreamPerThread musaStreamPerThread
 87#define cudaStreamSynchronize musaStreamSynchronize
 88#define cudaStreamWaitEvent musaStreamWaitEvent
 89#define cudaStream_t musaStream_t
 90#define cudaSuccess musaSuccess
 91
 92// Additional mappings for MUSA virtual memory pool
 93#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
 94#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
 95#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
 96#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
 97#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
 98#define CUdevice MUdevice
 99#define CUdeviceptr MUdeviceptr
100#define CUmemAccessDesc MUmemAccessDesc
101#define CUmemAllocationProp MUmemAllocationProp
102#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
103#define cuDeviceGet muDeviceGet
104#define cuDeviceGetAttribute muDeviceGetAttribute
105#define cuMemAddressFree muMemAddressFree
106#define cuMemAddressReserve muMemAddressReserve
107#define cuMemCreate muMemCreate
108#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
109#define cuMemMap muMemMap
110#define cuMemRelease muMemRelease
111#define cuMemSetAccess muMemSetAccess
112#define cuMemUnmap muMemUnmap
113#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
114#define cudaFuncSetAttribute musaFuncSetAttribute
115#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
116#define make_cudaExtent make_musaExtent
117#define make_cudaPitchedPtr make_musaPitchedPtr
118
// ---------------------------------------------------------------------------
// CUDA graph API -> MUSA graph API, plus related result/error codes and an
// occupancy query.
// ---------------------------------------------------------------------------

// Driver-API result codes.
#define CUresult         MUresult
#define CUDA_SUCCESS     MUSA_SUCCESS
#define cuGetErrorString muGetErrorString

// Runtime error codes.
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
#define cudaErrorInvalidDeviceFunction  musaErrorInvalidDeviceFunction

// Graph, node and parameter types.
#define cudaGraph_t               musaGraph_t
#define cudaGraphExec_t           musaGraphExec_t
#define cudaGraphNode_t           musaGraphNode_t
#define cudaGraphNodeType         musaGraphNodeType
#define cudaGraphNodeTypeKernel   musaGraphNodeTypeKernel
#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
#define cudaKernelNodeParams      musaKernelNodeParams

// Stream capture.
#define cudaStreamBeginCapture       musaStreamBeginCapture
#define cudaStreamEndCapture         musaStreamEndCapture
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed

// Graph lifecycle and inspection.
#define cudaGraphInstantiate         musaGraphInstantiate
#define cudaGraphLaunch              musaGraphLaunch
#define cudaGraphExecUpdate          musaGraphExecUpdate
#define cudaGraphGetNodes            musaGraphGetNodes
#define cudaGraphNodeGetType         musaGraphNodeGetType
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
#define cudaGraphExecDestroy         musaGraphExecDestroy
#define cudaGraphDestroy             musaGraphDestroy

// Occupancy query.
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
145
146typedef __mt_bfloat16 nv_bfloat16;
147typedef __mt_bfloat162 nv_bfloat162;