1#include "set-rows.cuh"
  2#include "cpy-utils.cuh"
  3
  4typedef void (*set_rows_kernel_t)(const char * src, char * dst);
  5
  6// Generic quantized set_rows kernel template
  7template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
  8static __global__ void k_set_rows_quant(const float * __restrict__ src0,
  9                                        const idx_t * __restrict__ src1,
 10                                        block_type * __restrict__ dst,
 11                                        const int64_t ne_total,
 12                                        const int64_t ne10,
 13                                        const int64_t ne11,
 14                                        const int64_t ne12,
 15                                        const int64_t ne13,
 16                                        const int64_t s01,
 17                                        const int64_t s02,
 18                                        const int64_t s03,
 19                                        const int64_t s10,
 20                                        const int64_t s11,
 21                                        const int64_t s12,
 22                                        const int64_t s1,
 23                                        const int64_t s2,
 24                                        const int64_t s3,
 25                                        const uint3   ne00,
 26                                        const uint3   ne01,
 27                                        const uint3   ne02,
 28                                        const uint3   ne11_fd,
 29                                        const uint3   ne12_fd) {
 30    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
 31
 32    if (i >= ne_total) {
 33        return;
 34    }
 35
 36    const int64_t i_base = i * qk;
 37    uint32_t      tmp    = (uint32_t) i_base;
 38    uint2         div_mod;
 39
 40    div_mod           = fast_div_modulo(tmp, ne00);
 41    const int64_t i00 = div_mod.y;
 42    tmp               = div_mod.x;
 43
 44    div_mod           = fast_div_modulo(tmp, ne01);
 45    const int64_t i01 = div_mod.y;
 46    tmp               = div_mod.x;
 47
 48    div_mod           = fast_div_modulo(tmp, ne02);
 49    const int64_t i02 = div_mod.y;
 50    const int64_t i03 = div_mod.x;
 51
 52    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
 53    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
 54    const int64_t i10 = i01;
 55
 56    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
 57
 58    const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
 59    block_type * dst_row_ptr = dst + (dst_row*s1 + i02*s2 + i03*s3) / sizeof(block_type);
 60
 61    const float * src_block = src0_row + i00;
 62    block_type * dst_block = dst_row_ptr + i00 / qk;
 63
 64    quantize_func(src_block, dst_block);
 65
 66    GGML_UNUSED(ne10);
 67    GGML_UNUSED(ne11);
 68    GGML_UNUSED(ne12);
 69    GGML_UNUSED(ne13);
 70}
 71
 72// Template dispatch function for quantized set_rows
 73template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
 74static void set_rows_cuda_quant(
 75        const float * src0_d, const idx_t * src1_d, block_type * dst_d,
 76        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
 77        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
 78        const size_t nb01, const size_t nb02, const size_t nb03,
 79        const size_t nb10, const size_t nb11, const size_t nb12,
 80        const size_t nb1, const size_t nb2, const size_t nb3,
 81        cudaStream_t stream) {
 82
 83    GGML_ASSERT(ne00 % qk == 0);
 84    const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
 85    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
 86    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
 87    const dim3 grid_size(num_blocks);
 88
 89    const int64_t s01 = nb01/sizeof(float);
 90    const int64_t s02 = nb02/sizeof(float);
 91    const int64_t s03 = nb03/sizeof(float);
 92    const int64_t s10 = nb10/sizeof(idx_t);
 93    const int64_t s11 = nb11/sizeof(idx_t);
 94    const int64_t s12 = nb12/sizeof(idx_t);
 95    const int64_t s1  = nb1;
 96    const int64_t s2  = nb2;
 97    const int64_t s3  = nb3;
 98
 99    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
100        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
101        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
102        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
103        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
104        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
105
106        k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
107            src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,
108            ne01_fd, ne02_fd, ne11_fd, ne12_fd);
109    }
110}
111
112template <typename src_t, typename idx_t, typename dst_t>
113static __global__ void k_set_rows(const src_t * __restrict__ src0,
114                                  const idx_t * __restrict__ src1,
115                                  dst_t * __restrict__ dst,
116                                  const int64_t ne_total,
117                                  const int64_t ne10,
118                                  const int64_t ne11,
119                                  const int64_t ne12,
120                                  const int64_t ne13,
121                                  const int64_t s01,
122                                  const int64_t s02,
123                                  const int64_t s03,
124                                  const int64_t s10,
125                                  const int64_t s11,
126                                  const int64_t s12,
127                                  const int64_t s1,
128                                  const int64_t s2,
129                                  const int64_t s3,
130                                  const uint3   ne00,
131                                  const uint3   ne01,
132                                  const uint3   ne02,
133                                  const uint3   ne11_fd,
134                                  const uint3   ne12_fd) {
135    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
136
137    if (i >= ne_total) {
138        return;
139    }
140
141    uint32_t tmp = (uint32_t) i;
142    uint2    div_mod;
143
144    div_mod           = fast_div_modulo(tmp, ne00);
145    const int64_t i00 = div_mod.y;
146    tmp               = div_mod.x;
147
148    div_mod           = fast_div_modulo(tmp, ne01);
149    const int64_t i01 = div_mod.y;
150    tmp               = div_mod.x;
151
152    div_mod           = fast_div_modulo(tmp, ne02);
153    const int64_t i02 = div_mod.y;
154    const int64_t i03 = div_mod.x;
155
156    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
157    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
158    const int64_t i10 = i01;
159
160    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
161
162    const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
163    dst_t * dst_row_ptr    = dst + dst_row*s1 + i02*s2 + i03*s3;
164
165    dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
166
167    GGML_UNUSED(ne10);
168    GGML_UNUSED(ne11);
169    GGML_UNUSED(ne12);
170    GGML_UNUSED(ne13);
171}
172
173template<typename src_t, typename idx_t, typename dst_t>
174static void set_rows_cuda(
175        const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d,
176        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
177        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
178        const size_t nb01, const size_t nb02, const size_t nb03,
179        const size_t nb10, const size_t nb11, const size_t nb12,
180        const size_t nb1, const size_t nb2, const size_t nb3,
181        cudaStream_t stream) {
182
183    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
184    const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
185    const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
186    const dim3 grid_size(num_blocks);
187
188
189    const int64_t s01 = nb01/sizeof(src_t);
190    const int64_t s02 = nb02/sizeof(src_t);
191    const int64_t s03 = nb03/sizeof(src_t);
192    const int64_t s10 = nb10/sizeof(idx_t);
193    const int64_t s11 = nb11/sizeof(idx_t);
194    const int64_t s12 = nb12/sizeof(idx_t);
195    const int64_t s1  = nb1/sizeof(dst_t);
196    const int64_t s2  = nb2/sizeof(dst_t);
197    const int64_t s3  = nb3/sizeof(dst_t);
198
199    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
200        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
201        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
202        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
203        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
204        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
205
206        k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
207                                                         s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
208                                                         ne11_fd, ne12_fd);
209    }
210}
211
212template<typename src_t, typename idx_t>
213static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
214    const src_t * src0_d = (const src_t *)src0->data;
215    const idx_t * src1_d = (const idx_t *)src1->data;
216
217    GGML_TENSOR_BINARY_OP_LOCALS
218
219    cudaStream_t stream = ctx.stream();
220
221
222    if (dst->type == GGML_TYPE_F32) {
223        set_rows_cuda(
224            src0_d, src1_d, (float*)dst->data,
225            ne00, ne01, ne02, ne03,
226            ne10, ne11, ne12, ne13,
227            nb01, nb02, nb03,
228            nb10, nb11, nb12,
229            nb1, nb2, nb3,
230            stream
231        );
232    } else if (dst->type == GGML_TYPE_F16) {
233        set_rows_cuda(
234            src0_d, src1_d, (half*)dst->data,
235            ne00, ne01, ne02, ne03,
236            ne10, ne11, ne12, ne13,
237            nb01, nb02, nb03,
238            nb10, nb11, nb12,
239            nb1, nb2, nb3,
240            stream
241        );
242    } else if (dst->type == GGML_TYPE_BF16) {
243        set_rows_cuda(
244            src0_d, src1_d, (nv_bfloat16*)dst->data,
245            ne00, ne01, ne02, ne03,
246            ne10, ne11, ne12, ne13,
247            nb01, nb02, nb03,
248            nb10, nb11, nb12,
249            nb1, nb2, nb3,
250            stream
251        );
252    } else if (dst->type == GGML_TYPE_Q4_0) {
253        set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
254            src0_d, src1_d, (block_q4_0*)dst->data,
255            ne00, ne01, ne02, ne03,
256            ne10, ne11, ne12, ne13,
257            nb01, nb02, nb03,
258            nb10, nb11, nb12,
259            nb1, nb2, nb3,
260            stream
261        );
262    } else if (dst->type == GGML_TYPE_Q4_1) {
263        set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
264            src0_d, src1_d, (block_q4_1*)dst->data,
265            ne00, ne01, ne02, ne03,
266            ne10, ne11, ne12, ne13,
267            nb01, nb02, nb03,
268            nb10, nb11, nb12,
269            nb1, nb2, nb3,
270            stream
271        );
272    } else if (dst->type == GGML_TYPE_Q5_0) {
273        set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
274            src0_d, src1_d, (block_q5_0*)dst->data,
275            ne00, ne01, ne02, ne03,
276            ne10, ne11, ne12, ne13,
277            nb01, nb02, nb03,
278            nb10, nb11, nb12,
279            nb1, nb2, nb3,
280            stream
281        );
282    } else if (dst->type == GGML_TYPE_Q5_1) {
283        set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
284            src0_d, src1_d, (block_q5_1*)dst->data,
285            ne00, ne01, ne02, ne03,
286            ne10, ne11, ne12, ne13,
287            nb01, nb02, nb03,
288            nb10, nb11, nb12,
289            nb1, nb2, nb3,
290            stream
291        );
292    } else if (dst->type == GGML_TYPE_Q8_0) {
293        set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
294            src0_d, src1_d, (block_q8_0*)dst->data,
295            ne00, ne01, ne02, ne03,
296            ne10, ne11, ne12, ne13,
297            nb01, nb02, nb03,
298            nb10, nb11, nb12,
299            nb1, nb2, nb3,
300            stream
301        );
302    } else if (dst->type == GGML_TYPE_IQ4_NL) {
303        set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
304            src0_d, src1_d, (block_iq4_nl*)dst->data,
305            ne00, ne01, ne02, ne03,
306            ne10, ne11, ne12, ne13,
307            nb01, nb02, nb03,
308            nb10, nb11, nb12,
309            nb1, nb2, nb3,
310            stream
311        );
312    } else {
313        GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
314    }
315}
316
317
318void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
319    const ggml_tensor * src0 = dst->src[0];
320    const ggml_tensor * src1 = dst->src[1];
321
322    GGML_ASSERT(src0->type == GGML_TYPE_F32);
323    GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);
324
325    if (src1->type == GGML_TYPE_I64) {
326        set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);
327    } else {
328        set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);
329    }
330}