1#pragma once
2
3#include <cuda_runtime.h>
4#include <cuda.h>
5#include <cublas_v2.h>
6#include <cuda_bf16.h>
7#include <cuda_fp16.h>
8
9#if CUDART_VERSION >= 12050
10#include <cuda_fp8.h>
11#endif // CUDART_VERSION >= 12050
12
13#if CUDART_VERSION >= 12080
14#include <cuda_fp4.h>
15#endif // CUDART_VERSION >= 12080
16
// Name shims for toolkits older than CUDA 11.2. CUDART_VERSION encodes the
// runtime version as 1000*major + 10*minor, so 11020 corresponds to 11.2.
// Every entry is an object-like macro: it only rewrites tokens that appear
// after this point and does not alter declarations already parsed from the
// headers included earlier in this file.
#if !(CUDART_VERSION >= 11020)
// cuBLAS compute-type spellings are redirected onto cudaDataType_t and its
// CUDA_R_* enumerators.
#define cublasComputeType_t cudaDataType_t
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
// cuBLAS math-mode spelling is mapped onto CUBLAS_TENSOR_OP_MATH.
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
// CUDA driver API device-attribute enumerator is mapped onto its older name.
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#endif // !(CUDART_VERSION >= 11020)