#include "cpy.cuh"
#include "dequantize.cuh"
#include "cpy-utils.cuh"
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
#include "ggml-musa/mudnn.cuh"
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY

typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

const int CUDA_CPY_TILE_DIM_2D = 32; // 2D tile dimension for transposed blocks
const int CUDA_CPY_BLOCK_NM = 8; // number of matrices (3rd dimension) processed per block
const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows

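// generic strided copy: one thread per element; the flat index i is decomposed into
// per-dimension indices so arbitrary (non-contiguous) src/dst byte strides are supported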
template <cpy_kernel_t cpy_1>
static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
        const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
        const int64_t nb12, const int64_t nb13) {
    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= ne) {
        return;
    }

    // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
    // then combine those indices with the corresponding byte offsets to get the total offsets
    const int64_t i03 = i/(ne00 * ne01 * ne02);
    const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
    const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
    const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;

    const int64_t i13 = i/(ne10 * ne11 * ne12);
    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
    const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

    cpy_1(cx + x_offset, cdst + dst_offset);
}

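// transposed copy of (batched) 2D matrices through a shared-memory tile:
// a CUDA_CPY_TILE_DIM_2D x CUDA_CPY_TILE_DIM_2D tile is loaded with coalesced reads and written back
// transposed; the +1 padding of the tile rows avoids shared-memory bank conflicts, and the float tile
// is reinterpreted as T so the same kernel covers both 2-byte and 4-byte element types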
template <typename T>
static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int64_t ne,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
        const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
        const int64_t nb12, const int64_t nb13) {

    const T * src = reinterpret_cast<const T *>(cx);
    T * dst = reinterpret_cast<T *>(cdst);

    const int64_t nmat = ne / (ne00 * ne01);
    const int64_t n    = ne00 * ne01;

    const int x  = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
    const int y  = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
    const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
    const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;

    __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D + 1];

#pragma unroll
    for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {

        const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
        if (imat >= nmat) {
            break;
        }

#pragma unroll
        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
            if (x < ne01 && y + j < ne00) {
                const int row = threadIdx.y + j;
                const int col = threadIdx.x * sizeof(float)/sizeof(T);
                T * tile2 = reinterpret_cast<T *>(tile[row]);
                tile2[col] = src[imat*n + (y+j)*ne01 + x];
            }
        }

        __syncthreads();

#pragma unroll
        for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
            if (ty + j < ne01 && tx < ne00) {
                const int col = (threadIdx.y + j)*sizeof(float)/sizeof(T);
                const T * tile2 = reinterpret_cast<const T *>(tile[threadIdx.x]);
                dst[imat*n + (ty+j)*ne00 + tx] = tile2[col];
            }
        }

        // make sure all reads of the tile have completed before it is overwritten in the next iteration
        __syncthreads();
    }

    GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11,
                     nb12, nb13);
}

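// dequantize a single q8_0 block (QK8_0 values) into QK8_0 consecutive floats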
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
    float * cdstf = (float *)(cdsti);

#pragma unroll
    for (int j = 0; j < QK8_0; j += 2) {
        float2 dq;
        dequantize_q8_0(cxi, 0, j, dq);
        *(cdstf + j)     = dq.x;
        *(cdstf + j + 1) = dq.y;
    }
}

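// generic block dequantization: dequant yields value pairs (dq.x, dq.y) that belong to
// positions j and j + qk/2 of the block, matching the interleaved quantized data layout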
template<dequantize_kernel_t dequant, int qk>
static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
    float * cdstf = (float *)(cdsti);

#pragma unroll
    for (int j = 0; j < qk/2; j++) {
        float2 dq;
        dequant(cxi, 0, j, dq);
        *(cdstf + j)        = dq.x;
        *(cdstf + j + qk/2) = dq.y;
    }
}

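// quantizing copy: each thread reads qk source floats and writes one quantized block;
// the destination is addressed in units of whole blocks (i10/qk)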
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
        const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
        const int64_t nb12, const int64_t nb13) {
    const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;

    if (i >= ne) {
        return;
    }

    const int64_t i03 = i/(ne00 * ne01 * ne02);
    const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
    const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
    const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;

    const int64_t i13 = i/(ne10 * ne11 * ne12);
    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
    const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

    cpy_blck(cx + x_offset, cdst + dst_offset);
}

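// dequantizing copy: each thread expands one quantized block of qk values into qk floats;
// the source is addressed in units of whole blocks (i00/qk)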
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
        const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
        const int64_t nb12, const int64_t nb13) {
    const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;

    if (i >= ne) {
        return;
    }

    const int64_t i03 = i/(ne00 * ne01 * ne02);
    const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
    const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
    const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03*nb03;

    const int64_t i13 = i/(ne10 * ne11 * ne12);
    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
    const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

    cpy_blck(cx + x_offset, cdst + dst_offset);
}

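// fast path for contiguous src/dst: plain element-wise copy with type conversion, no stride arithmetic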
template<typename src_t, typename dst_t>
static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= ne) {
        return;
    }

    const src_t * x = (const src_t *) cx;
    dst_t * dst = (dst_t *) cdst;

    dst[i] = ggml_cuda_cast<dst_t>(x[i]);
}

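// host-side launcher for the contiguous copy kernel; an illustrative (hypothetical) call would be
//   ggml_cpy_scalar_contiguous_cuda<float, half>(src_ptr, dst_ptr, ggml_nelements(src), stream);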
template<typename src_t, typename dst_t>
static void ggml_cpy_scalar_contiguous_cuda(
        const char * cx, char * cdst, const int64_t ne,
        cudaStream_t stream) {

    const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
    GGML_ASSERT(num_blocks < UINT_MAX);
    cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
        (cx, cdst, ne);
}

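// host-side launcher for the scalar copy: with transposed == true the copy is treated as a batch of
// 2D transposes (ne[3] == 1) and dispatched to the tiled transpose kernel, otherwise the generic
// strided kernel is used with one thread per element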
template<typename src_t, typename dst_t, bool transposed = false>
static void ggml_cpy_scalar_cuda(
        const char * cx, char * cdst, const int64_t ne,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
        const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {

    if (transposed) {
        GGML_ASSERT(ne == ne00*ne01*ne02); // assumes ne[3] == 1
        int64_t ne00n, ne01n, ne02n;
        if (nb00 <= nb02) { // handling the nb00 == nb02 case here as well is most likely safe
            ne00n = ne00;
            ne01n = ne01;
            ne02n = ne02;
        } else {
            ne00n = ne00;
            ne01n = ne01*ne02;
            ne02n = 1;
        }

        int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
        int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
        int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM;
        GGML_ASSERT(grid_x < UINT_MAX);
        GGML_ASSERT(grid_y < USHRT_MAX);
        GGML_ASSERT(grid_z < USHRT_MAX);
        dim3 dimGrid(grid_x, grid_y, grid_z);
        dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
        cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
            (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
    } else {
        const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
        GGML_ASSERT(num_blocks < UINT_MAX);
        cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
            (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
    }
}

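// launchers for copies between F32 and the quantized types; for the F32 -> quantized direction the
// number of elements must be a multiple of the quantization block size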
static void ggml_cpy_f32_q8_0_cuda(
        const char * cx, char * cdst, const int64_t ne,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
        const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK8_0 == 0);
    const int64_t num_blocks = ne / QK8_0;
    GGML_ASSERT(num_blocks < UINT_MAX);
    cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q8_0_f32_cuda(
        const char * cx, char * cdst, const int64_t ne,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
        const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {

    const int64_t num_blocks = ne;
    GGML_ASSERT(num_blocks < UINT_MAX);
    cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_q4_0_cuda(
        const char * cx, char * cdst, const int64_t ne,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
        const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK4_0 == 0);
    const int64_t num_blocks = ne / QK4_0;
    GGML_ASSERT(num_blocks < UINT_MAX);
    cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q4_0_f32_cuda(
        const char * cx, char * cdst, const int64_t ne,
        const int64_t ne00, const int64_t ne01, const int64_t ne02,
        const int64_t nb00, const int64_t nb01, const int64_t nb02,
        const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
        const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
        cudaStream_t stream) {
    const int64_t num_blocks = ne;
    GGML_ASSERT(num_blocks < UINT_MAX);
    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_q4_1_cuda(
        const char * cx, char * cdst, const int64_t ne,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
        const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK4_1 == 0);
    const int64_t num_blocks = ne / QK4_1;
    GGML_ASSERT(num_blocks < UINT_MAX);
    cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q4_1_f32_cuda(
        const char * cx, char * cdst, const int64_t ne,
        const int64_t ne00, const int64_t ne01, const int64_t ne02,
        const int64_t nb00, const int64_t nb01, const int64_t nb02,
        const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
        const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
        cudaStream_t stream) {
    const int64_t num_blocks = ne;
    GGML_ASSERT(num_blocks < UINT_MAX);
    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_q5_0_cuda(
        const char * cx, char * cdst, const int64_t ne,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
        const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK5_0 == 0);
    const int64_t num_blocks = ne / QK5_0;
    GGML_ASSERT(num_blocks < UINT_MAX);
    cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q5_0_f32_cuda(
        const char * cx, char * cdst, const int64_t ne,
        const int64_t ne00, const int64_t ne01, const int64_t ne02,
        const int64_t nb00, const int64_t nb01, const int64_t nb02,
        const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
        const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
        cudaStream_t stream) {
    const int64_t num_blocks = ne;
    GGML_ASSERT(num_blocks < UINT_MAX);
    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_q5_1_cuda(
        const char * cx, char * cdst, const int64_t ne,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
        const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK5_1 == 0);
    const int64_t num_blocks = ne / QK5_1;
    GGML_ASSERT(num_blocks < UINT_MAX);
    cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q5_1_f32_cuda(
        const char * cx, char * cdst, const int64_t ne,
        const int64_t ne00, const int64_t ne01, const int64_t ne02,
        const int64_t nb00, const int64_t nb01, const int64_t nb02,
        const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
        const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
        cudaStream_t stream) {
    const int64_t num_blocks = ne;
    GGML_ASSERT(num_blocks < UINT_MAX);
    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
        cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
        ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_iq4_nl_cuda(
        const char * cx, char * cdst, const int64_t ne,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
        const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {

    GGML_ASSERT(ne % QK4_NL == 0);
    const int64_t num_blocks = ne / QK4_NL;
    GGML_ASSERT(num_blocks < UINT_MAX);
    cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

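// main dispatch for GGML_OP_CPY: same-type contiguous copies go through cudaMemcpyAsync
// (or muDNN on MUSA builds), everything else is routed to a type-pair specific kernel above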
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
    const int64_t ne = ggml_nelements(src0);
    GGML_ASSERT(ne == ggml_nelements(src1));

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];

    //GGML_ASSERT(src0->ne[3] == 1);

    const int64_t nb00 = src0->nb[0];
    const int64_t nb01 = src0->nb[1];
    const int64_t nb02 = src0->nb[2];
    const int64_t nb03 = src0->nb[3];

    const int64_t ne10 = src1->ne[0];
    const int64_t ne11 = src1->ne[1];
    const int64_t ne12 = src1->ne[2];

    //GGML_ASSERT(src1->ne[3] == 1);

    const int64_t nb10 = src1->nb[0];
    const int64_t nb11 = src1->nb[1];
    const int64_t nb12 = src1->nb[2];
    const int64_t nb13 = src1->nb[3];

    cudaStream_t main_stream = ctx.stream();

    char * src0_ddc = (char *) src0->data;
    char * src1_ddc = (char *) src1->data;

    const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
    const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
        src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);

    if (src0->type == src1->type && contiguous_srcs) {
        GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
        if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
            CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
        } else
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
        {
            CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
        if (can_be_transposed) {
            ggml_cpy_scalar_cuda<float, float, true>
                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        } else {
            ggml_cpy_scalar_cuda<float, float>
                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
        if (contiguous_srcs) {
            ggml_cpy_scalar_contiguous_cuda<float, nv_bfloat16>
                (src0_ddc, src1_ddc, ne, main_stream);
        } else {
            ggml_cpy_scalar_cuda<float, nv_bfloat16>
                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
        if (contiguous_srcs) {
            ggml_cpy_scalar_contiguous_cuda<float, half>
                (src0_ddc, src1_ddc, ne, main_stream);
        } else {
            ggml_cpy_scalar_cuda<float, half>
                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
        ggml_cpy_f32_q8_0_cuda
            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q8_0_f32_cuda
            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
        ggml_cpy_f32_q4_0_cuda
            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q4_0_f32_cuda
            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
        ggml_cpy_f32_q4_1_cuda
            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q4_1_f32_cuda
            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
        ggml_cpy_f32_q5_0_cuda
            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q5_0_f32_cuda
            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
        ggml_cpy_f32_iq4_nl_cuda
            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
        ggml_cpy_f32_q5_1_cuda
            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
        ggml_cpy_q5_1_f32_cuda
            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
        if (can_be_transposed) {
            ggml_cpy_scalar_cuda<half, half, true>
                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        } else {
            ggml_cpy_scalar_cuda<half, half>
                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
        if (contiguous_srcs) {
            ggml_cpy_scalar_contiguous_cuda<half, nv_bfloat16>
                (src0_ddc, src1_ddc, ne, main_stream);
        } else {
            ggml_cpy_scalar_cuda<half, nv_bfloat16>
                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
        if (contiguous_srcs) {
            ggml_cpy_scalar_contiguous_cuda<half, float>
                (src0_ddc, src1_ddc, ne, main_stream);
        } else {
            ggml_cpy_scalar_cuda<half, float>
                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
        if (can_be_transposed) {
            ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16, true>
                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        } else {
            ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16>
                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
        if (contiguous_srcs) {
            ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, half>
                (src0_ddc, src1_ddc, ne, main_stream);
        } else {
            ggml_cpy_scalar_cuda<nv_bfloat16, half>
                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
        if (contiguous_srcs) {
            ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, float>
                (src0_ddc, src1_ddc, ne, main_stream);
        } else {
            ggml_cpy_scalar_cuda<nv_bfloat16, float>
                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
        if (can_be_transposed) {
            ggml_cpy_scalar_cuda<int32_t, int32_t, true>
                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        } else {
            ggml_cpy_scalar_cuda<int32_t, int32_t>
                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
        if (contiguous_srcs) {
            ggml_cpy_scalar_contiguous_cuda<float, int32_t>
                (src0_ddc, src1_ddc, ne, main_stream);
        } else {
            ggml_cpy_scalar_cuda<float, int32_t>
                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
        if (contiguous_srcs) {
            ggml_cpy_scalar_contiguous_cuda<int32_t, float>
                (src0_ddc, src1_ddc, ne, main_stream);
        } else {
            ggml_cpy_scalar_cuda<int32_t, float>
                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
        }
    } else {
        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                   ggml_type_name(src0->type), ggml_type_name(src1->type));
    }
}

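// GGML_OP_DUP is implemented as a copy of src0 into dst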
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    ggml_cuda_cpy(ctx, src0, dst);
}