Diffstat (limited to 'llama.cpp/ggml/src/ggml-cuda/convert.cuh')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-cuda/convert.cuh | 56 |
1 file changed, 56 insertions(+), 0 deletions(-)
diff --git a/llama.cpp/ggml/src/ggml-cuda/convert.cuh b/llama.cpp/ggml/src/ggml-cuda/convert.cuh
new file mode 100644
index 0000000..09f9a33
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/convert.cuh
@@ -0,0 +1,56 @@
+#pragma once
+#include "common.cuh"
+
+#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
+
+template<typename T>
+using to_t_cuda_t = void (*)(const void * x, T * y, int64_t k, cudaStream_t stream);
+
+typedef to_t_cuda_t<float> to_fp32_cuda_t;
+typedef to_t_cuda_t<half> to_fp16_cuda_t;
+typedef to_t_cuda_t<nv_bfloat16> to_bf16_cuda_t;
+
+to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);
+
+to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type);
+
+to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);
+
+// TODO more general support for non-contiguous inputs
+
+template<typename T>
+using to_t_nc_cuda_t = void (*)(const void * x, T * y,
+    int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
+    int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
+
+typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
+typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
+typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
+
+to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
+to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
+to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
+
+template<typename dst_t, typename src_t>
+ __host__ __device__ inline dst_t ggml_cuda_cast(src_t x) {
+    if constexpr (std::is_same_v<dst_t, src_t>) {
+        return x;
+    } else if constexpr(std::is_same_v<dst_t, nv_bfloat16>) {
+        return __float2bfloat16(float(x));
+    } else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
+        return __bfloat162float(x);
+    } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
+        return __float22half2_rn(x);
+    } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
+        // bypass compile error on cuda 12.0.1
+#ifdef GGML_USE_HIP
+        return __float22bfloat162_rn(x);
+#else
+        return {x.x, x.y};
+#endif // GGML_USE_HIP
+    } else if constexpr(std::is_same_v<dst_t, int32_t>) {
+        return int32_t(x);
+    } else {
+        return float(x);
+    }
+}
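
The header only declares the conversion function-pointer types and their getters; the implementations live in the matching convert.cu. As a minimal usage sketch (not part of this commit; the buffer names, element count, and the GGML_TYPE_Q4_0 choice are illustrative assumptions), a caller might dequantize a contiguous quantized buffer to fp16 like this:

#include "convert.cuh"

// usage sketch (assumed names): dequantize k contiguous Q4_0 values to fp16
static bool dequantize_q4_0_to_fp16(const void * d_src, half * d_dst,
                                    int64_t k, cudaStream_t stream) {
    const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(GGML_TYPE_Q4_0);
    if (to_fp16 == nullptr) {
        return false; // no fp16 conversion registered for this type
    }
    to_fp16(d_src, d_dst, k, stream); // runs asynchronously on the given stream
    return true;
}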

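ggml_cuda_cast resolves the conversion path at compile time via if constexpr, so a single templated kernel can serve every supported type pair. The following copy kernel is an illustrative sketch only (the kernel name and launch shape are assumptions, not code from this commit); it reuses CUDA_DEQUANTIZE_BLOCK_SIZE from the header above:

// illustrative sketch: element-wise copy with compile-time type conversion
template<typename dst_t, typename src_t>
static __global__ void k_cast_copy(const src_t * x, dst_t * y, const int64_t k) {
    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    y[i] = ggml_cuda_cast<dst_t>(x[i]); // float -> nv_bfloat16 lands on __float2bfloat16
}

// e.g. float -> bf16, one thread per element:
//   const int64_t nblocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
//   k_cast_copy<<<nblocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(d_src_f32, d_dst_bf16, k);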