1#pragma once
 2
 3#include "common.cuh"
 4#include "mmq.cuh"
 5
 6#include <cstdint>
 7
// Block sizes for the quantization paths declared in this header.
// NOTE(review): presumably these are the CUDA thread-block sizes used when
// launching the generic (quantize_row_*) and MMQ (quantize_mmq_*) kernels
// respectively; the launches live in the .cu file — confirm there.
#define CUDA_QUANTIZE_BLOCK_SIZE     256
#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128

// Compile-time guards: MATRIX_ROW_PADDING (defined outside this file,
// presumably via common.cuh) must be an exact multiple of the work covered per
// block, otherwise a launch over a padded row could index past the padding.
// The MMQ check uses 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ.
// NOTE(review): the factor 4 suggests each MMQ thread handles 4 values —
// confirm against the MMQ quantize kernel.
static_assert(MATRIX_ROW_PADDING %    CUDA_QUANTIZE_BLOCK_SIZE      == 0, "Risk of out-of-bounds access.");
static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
13
// Pointer type matching the common signature of the quantize_*_cuda host
// entry points in this header. A caller can therefore hold any of them in
// one variable and invoke them uniformly.
//
// Parameter meanings (NOTE(review): inferred from names, confirm in the .cu):
//   x            float source data to quantize.
//   ids          int32 row indices; presumably used to gather rows, possibly nullable.
//   vy           destination buffer for the quantized output (untyped).
//   type_src0    ggml type of the other operand; presumably selects the output layout.
//   ne00         presumably the row length of src0.
//   s01/s02/s03  strides of x along dims 1..3.
//   ne0..ne3     extents of the region to quantize.
//   stream       CUDA stream the work is issued on.
using quantize_cuda_t = void (*)(
        const float * x, const int32_t * ids, void * vy,
        ggml_type type_src0, int64_t ne00,
        int64_t s01, int64_t s02, int64_t s03,
        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
        cudaStream_t stream);
18
// Host entry point (declaration only; the definition is in a .cu file).
// Its signature is identical to quantize_cuda_t, so it can be stored in that
// pointer type. NOTE(review): by its name this quantizes the float rows of x
// into the q8_1 format written to vy, enqueued on `stream`. Confirm the
// exact layout against the implementation.
void quantize_row_q8_1_cuda(
        const float * x, const int32_t * ids, void * vy,
        ggml_type type_src0, int64_t ne00,
        int64_t s01, int64_t s02, int64_t s03,
        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
        cudaStream_t stream);
23
// Host entry point (declaration only; the definition is in a .cu file).
// Its signature is identical to quantize_cuda_t. NOTE(review): by its name
// this is the MMQ counterpart of quantize_row_q8_1_cuda. It presumably writes
// q8_1 data into vy in the layout expected by the MMQ kernels (mmq.cuh is
// included by this header). Confirm against the implementation.
void quantize_mmq_q8_1_cuda(
        const float * x, const int32_t * ids, void * vy,
        ggml_type type_src0, int64_t ne00,
        int64_t s01, int64_t s02, int64_t s03,
        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
        cudaStream_t stream);
28
// Host entry point (declaration only; the definition is in a .cu file).
// Its signature is identical to quantize_cuda_t. NOTE(review): by its name
// this quantizes x into an MXFP4 representation for the MMQ kernels and
// writes it to vy on `stream`. Confirm block layout and any hardware
// requirements against the implementation.
void quantize_mmq_mxfp4_cuda(
        const float * x, const int32_t * ids, void * vy,
        ggml_type type_src0, int64_t ne00,
        int64_t s01, int64_t s02, int64_t s03,
        int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3,
        cudaStream_t stream);