#pragma once

#define HIP_DISABLE_WARP_SYNC_BUILTINS 1
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>

#include <cstdint>
#include <cstdlib>
#include <limits>

#if defined(GGML_HIP_ROCWMMA_FATTN)
#include <rocwmma/rocwmma-version.hpp>
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
 12
 13#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
 14#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
 15#define CUBLAS_OP_N HIPBLAS_OP_N
 16#define CUBLAS_OP_T HIPBLAS_OP_T
 17#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 18#define CUBLAS_TF32_TENSOR_OP_MATH 0
 19#define CUDA_R_16F  HIPBLAS_R_16F
 20#define CUDA_R_16BF HIPBLAS_R_16B
 21#define CUDA_R_32F  HIPBLAS_R_32F
 22#define CUBLAS_SIDE_RIGHT HIPBLAS_SIDE_RIGHT
 23#define CUBLAS_FILL_MODE_UPPER HIPBLAS_FILL_MODE_UPPER
 24#define CUBLAS_DIAG_NON_UNIT HIPBLAS_DIAG_NON_UNIT
 25#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
 26#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
 27#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
 28#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
 29#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
 30#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
 31#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
 32#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
 33#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
 34#define __all_sync(mask, var) __all(var)
 35#define __any_sync(mask, var) __any(var)
 36#define cublasStrsmBatched hipblasStrsmBatched
 37#define cublasCreate hipblasCreate
 38#define cublasDestroy hipblasDestroy
 39#define cublasGemmEx hipblasGemmEx
 40#define cublasGemmBatchedEx hipblasGemmBatchedEx
 41#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
 42#define cublasHandle_t hipblasHandle_t
 43#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
 44#define cublasSetStream hipblasSetStream
 45#define cublasSgemm hipblasSgemm
 46#define cublasStatus_t hipblasStatus_t
 47#define cublasOperation_t hipblasOperation_t
 48#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch
 49#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
 50#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
 51#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
 52#define cudaDeviceGetAttribute hipDeviceGetAttribute
 53#define cudaDeviceProp hipDeviceProp_t
 54#define cudaDeviceSynchronize hipDeviceSynchronize
 55#define cudaError_t hipError_t
 56#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
 57#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
 58#define cudaEventCreateWithFlags hipEventCreateWithFlags
 59#define cudaEventDisableTiming hipEventDisableTiming
 60#define cudaEventRecord hipEventRecord
 61#define cudaEventSynchronize hipEventSynchronize
 62#define cudaEvent_t hipEvent_t
 63#define cudaEventDestroy hipEventDestroy
 64#define cudaFree hipFree
 65#define cudaFreeHost hipHostFree
 66#define cudaGetDevice hipGetDevice
 67#define cudaGetDeviceCount hipGetDeviceCount
 68#define cudaGetDeviceProperties hipGetDeviceProperties
 69#define cudaGetErrorString hipGetErrorString
 70#define cudaGetLastError hipGetLastError
 71#define cudaHostRegister hipHostRegister
 72#define cudaHostRegisterPortable hipHostRegisterPortable
 73#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
 74#define cudaHostUnregister hipHostUnregister
 75#define cudaLaunchCooperativeKernel hipLaunchCooperativeKernel
 76#define cudaLaunchHostFunc hipLaunchHostFunc
 77#define cudaMalloc hipMalloc
 78#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
 79#define cudaMallocManaged hipMallocManaged
 80#define cudaMemAdvise hipMemAdvise
 81#define cudaMemcpy hipMemcpy
 82#define cudaMemcpyAsync hipMemcpyAsync
 83#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
 84#define cudaMemcpy2DAsync hipMemcpy2DAsync
 85#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
 86#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
 87#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
 88#define cudaMemcpyKind hipMemcpyKind
 89#define cudaMemset hipMemset
 90#define cudaMemsetAsync hipMemsetAsync
 91#define cudaMemGetInfo hipMemGetInfo
 92#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
 93#define cudaSetDevice hipSetDevice
 94#define cuDeviceGet hipDeviceGet
 95#define CUdevice hipDevice_t
 96#define CUdeviceptr hipDeviceptr_t
 97#define cuMemUnmap hipMemUnmap
 98#define CUmemAccessDesc hipMemAccessDesc
 99#define cuMemAddressFree hipMemAddressFree
100#define cuMemRelease hipMemRelease
101#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
102#define cuMemCreate hipMemCreate
103#define cuMemAddressReserve hipMemAddressReserve
104#define cuMemMap hipMemMap
105#define cuMemSetAccess hipMemSetAccess
106#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
107#define CUmemAllocationProp hipMemAllocationProp
108#define cuDeviceGetAttribute hipDeviceGetAttribute
109#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
110#define cudaStreamDestroy hipStreamDestroy
111#define cudaStreamFireAndForget hipStreamFireAndForget
112#define cudaStreamNonBlocking hipStreamNonBlocking
113#define cudaStreamPerThread hipStreamPerThread
114#define cudaStreamSynchronize hipStreamSynchronize
115#define cudaStreamWaitEvent hipStreamWaitEvent
116#define cudaGraphExec_t hipGraphExec_t
117#define cudaGraphNode_t hipGraphNode_t
118#define cudaKernelNodeParams hipKernelNodeParams
119#define cudaKernelNodeParams hipKernelNodeParams
120#define cudaGraphExecDestroy hipGraphExecDestroy
121#define cudaGraphLaunch hipGraphLaunch
122#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
123#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
124#define cudaGraphNodeType hipGraphNodeType
125#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
126#define cudaGraphInstantiate hipGraphInstantiate
127#define cudaStreamEndCapture hipStreamEndCapture
128#define cudaGraphDestroy hipGraphDestroy
129#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
130#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
131#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
132#define cudaGraphNodeGetType hipGraphNodeGetType
133#define cudaGraphGetNodes hipGraphGetNodes
134#define cudaGraphExecUpdate hipGraphExecUpdate
135#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
136#define cudaStreamBeginCapture hipStreamBeginCapture
137#define cudaGraph_t hipGraph_t
138#define cudaStream_t hipStream_t
139#define cudaSuccess hipSuccess
140#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
141#define cudaFuncSetAttribute hipFuncSetAttribute
142#define cudaFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize
// Stand-in for __trap(): calls abort() and then tells the compiler that control
// never continues past this point. The do { } while(0) wrapper makes the
// expansion a single statement.
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
// cuBLAS status codes mapped one-to-one onto the hipBLAS status codes.
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
153
// Compute-type selection depends on the HIP release.
// From HIP_VERSION 60500000 onwards the dedicated hipblasComputeType_t
// enumerators are used and cudaDataType_t maps to hipDataType; for earlier
// versions both type names map to hipblasDatatype_t and the compute types are
// spelled with HIPBLAS_R_* data-type values (the FAST_16F variant becomes
// plain HIPBLAS_R_32F).
#if HIP_VERSION >= 60500000
#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
#define cublasComputeType_t hipblasComputeType_t
#define cudaDataType_t hipDataType
#else
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define cublasComputeType_t hipblasDatatype_t
#define cudaDataType_t hipblasDatatype_t
#endif // HIP_VERSION >= 60500000
167
168#if !defined(__HIP_PLATFORM_AMD__)
169#error "The HIP backend supports only AMD targets"
170#endif // !defined(__HIP_PLATFORM_AMD__)
171
172#define __CUDA_ARCH__ 1300
173
// Architecture family macros derived from the compiler's per-target __gfx*__ /
// __GFX*__ defines. Each generation gets its own macro, and every family
// (GCN, CDNA, RDNA) additionally gets an umbrella macro.

#if defined(__gfx900__) || defined(__gfx906__)
#define GCN5
#endif // defined(__gfx900__) || defined(__gfx906__)

#if defined(__gfx803__)
#define GCN4
#endif // defined(__gfx803__)

#if defined(GCN5) || defined(GCN4)
#define GCN
#endif // defined(GCN5) || defined(GCN4)

#if defined(__gfx942__)
#define CDNA3
#endif // defined(__gfx942__)

#if defined(__gfx90a__)
#define CDNA2
#endif // defined(__gfx90a__)

#if defined(__gfx908__)
#define CDNA1
#endif // defined(__gfx908__)

#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
#define CDNA // For the entire family
#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1)

#if defined(__GFX12__)
#define RDNA4
#endif // defined(__GFX12__)

#if defined(__GFX11__)
#define RDNA3
#endif // defined(__GFX11__)

#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
#define RDNA2
#endif // defined(__gfx1030__) || ... || defined(__gfx1037__)

#if defined(__gfx1010__) || defined(__gfx1012__)
#define RDNA1
#endif // defined(__gfx1010__) || defined(__gfx1012__)

#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
#define RDNA // For the entire family
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
222
223#ifndef __has_builtin
224    #define __has_builtin(x) 0
225#endif
226
227typedef __hip_bfloat16 nv_bfloat16;
228typedef __hip_bfloat162 nv_bfloat162;
229
230typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
231typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
232static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
233    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
234    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
235#if __has_builtin(__builtin_elementwise_sub_sat)
236    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
237    return reinterpret_cast<const int &>(c);
238#else
239    int8x4_t c;
240    int16_t tmp;
241#pragma unroll
242    for (int i = 0; i < 4; i++) {
243        tmp = va[i] - vb[i];
244        if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
245        if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
246        c[i] = tmp;
247    }
248    return reinterpret_cast<int &>(c);
249#endif // __has_builtin(__builtin_elementwise_sub_sat)
250}
251
252static __device__ __forceinline__ int __vsub4(const int a, const int b) {
253    return __vsubss4(a, b);
254}
255
256static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
257    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
258    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
259    unsigned int c;
260    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
261#pragma unroll
262    for (int i = 0; i < 4; ++i) {
263        vc[i] = va[i] == vb[i] ? 0xff : 0x00;
264    }
265    return c;
266}
267
268static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
269    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
270    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
271    unsigned int c;
272    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
273#pragma unroll
274    for (int i = 0; i < 4; ++i) {
275        vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
276    }
277    return c;
278}