#include <algorithm>
#include "cumsum.cuh"
#include "convert.cuh"
#include "ggml-cuda/common.cuh"
#include "ggml.h"

#ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh>
#endif // GGML_CUDA_USE_CUB

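// Cumulative sum along the innermost dimension (ne00) based on cub::BlockScan.
// One thread block processes one row (i1, i2, i3); the row is consumed in tiles of
// BLOCK_SIZE * UNROLL_FACTOR elements with a running carry kept in shared memory.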
11template<typename T, int BLOCK_SIZE>
12static __global__ void cumsum_cub_kernel(
13 const T * __restrict__ src,
14 T * __restrict__ dst,
15 const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
16 const int64_t s01, const int64_t s02, const int64_t s03,
17 const int64_t s1, const int64_t s2, const int64_t s3) {
18#ifdef GGML_CUDA_USE_CUB
19 using BlockScanT = cub::BlockScan<T, BLOCK_SIZE>;
20
21 __shared__ typename BlockScanT::TempStorage temp_storage;
22 __shared__ T block_carry;
23
24 const int tid = threadIdx.x;
25 constexpr int UNROLL_FACTOR = 4;
26 constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR;
27
28 const int64_t i1 = blockIdx.x;
29 const int64_t i2 = blockIdx.y;
30 const int64_t i3 = blockIdx.z;
31
32 if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
33 return;
34 }
35
36 const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
37 T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
38
39 if (tid == 0) {
40 block_carry = 0;
41 }
42 __syncthreads();
43
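    // Per tile: each thread scans UNROLL_FACTOR consecutive elements in registers,
    // the per-thread totals are combined with a block-wide inclusive scan, and the
    // carry from the previous tiles is added back when the results are written out.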
    for (int64_t start = 0; start < ne00; start += TILE_SIZE) {
        T items[UNROLL_FACTOR];
        T thread_sum = T(0);

#pragma unroll
        for (int i = 0; i < UNROLL_FACTOR; i++) {
            int64_t idx = start + tid * UNROLL_FACTOR + i;
            T val = (idx < ne00) ? src_row[idx] : T(0);
            thread_sum += val;
            items[i] = thread_sum;
        }

        // Block-wide scan on thread sums
        T thread_prefix;
        T block_total;
        BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total);
        __syncthreads();

        // Add offset to each item and store
        T thread_offset = thread_prefix - thread_sum + block_carry;
#pragma unroll
        for (int i = 0; i < UNROLL_FACTOR; i++) {
            int64_t idx = start + tid * UNROLL_FACTOR + i;
            if (idx < ne00) {
                dst_row[idx] = items[i] + thread_offset;
            }
        }

        __syncthreads();

        // Update carry for next tile
        if (tid == 0) {
            block_carry += block_total;
        }
    }
#else
    NO_DEVICE_CODE;
#endif // GGML_CUDA_USE_CUB
}

// Fallback kernel implementation
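// Hierarchical inclusive scan used when the CUB path is unavailable or not worthwhile:
// registers -> warp -> block, with a carry kept in shared memory linking consecutive
// chunks of a row. One thread block processes one row (i1, i2, i3).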
template<typename T>
static __global__ void cumsum_kernel(
        const T * src, T * dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03,
        const int64_t s0, const int64_t s1, const int64_t s2, const int64_t s3) {

    GGML_UNUSED_VARS(s00, s0);

    const int tid = threadIdx.x;
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    const int lane = tid % warp_size;
    const int warp = tid / warp_size;
    const int warps_per_block = blockDim.x / warp_size;

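    // Dynamic shared memory layout (floats):
    //   s_vals        [blockDim.x]      per-thread warp-inclusive scan results
    //   s_warp_sums   [warps_per_block] per-warp totals, exclusive-scanned in place
    //   s_carry       [1]               running sum of all previous chunks
    //   s_chunk_total [1]               total of the current chunk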
    extern __shared__ float smem[];
    float * s_vals = smem;
    float * s_warp_sums = smem + blockDim.x;
    float * s_carry = smem + blockDim.x + warps_per_block;
    float * s_chunk_total = s_carry + 1;

    // Initialize carry
    if (tid == 0) {
        *s_carry = 0.0f;
    }
    __syncthreads();

    const int64_t i3 = blockIdx.z;
    const int64_t i2 = blockIdx.y;
    const int64_t i1 = blockIdx.x;
    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
        return;
    }

    const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
    T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;

    // register blocking: process 4 elements per thread to hide latency
    // and reduce synchronization overhead
    constexpr int num_unroll = 4;
    T temp[num_unroll];

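    // Each pass handles one chunk of num_unroll * blockDim.x elements: scan in registers,
    // combine the per-thread totals with a warp scan, combine the per-warp totals in warp 0,
    // then add the carry from the previous chunks when writing the results back.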
    for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) {
        int64_t idx = i + tid * num_unroll;

        // thread local sequential scan
        temp[0] = (idx < ne00 ? src_row[idx] : T(0));
#pragma unroll
        for (int64_t j = 1; j < num_unroll; j++) {
            temp[j] = temp[j - 1];
            if (idx + j < ne00) {
                temp[j] += src_row[idx + j];
            }
        }

        // last element is the sum of all values assigned to this thread
        float val = (idx < ne00) ? ggml_cuda_cast<float, T>(temp[num_unroll - 1]) : 0.0f;

        // Warp inclusive scan
        val = warp_prefix_inclusive_sum<T, warp_size>(val);
        s_vals[tid] = val;

        if (lane == warp_size - 1) {
            s_warp_sums[warp] = val;
        }
        __syncthreads();

        // Exclusive scan of warp sums (warp 0 only)
        if (warp == 0) {
            float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
            float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
            if (tid < warps_per_block) {
                s_warp_sums[tid] = inc - w; // exclusive sum
            }
            if (tid == warps_per_block - 1) {
                *s_chunk_total = inc; // total sum of this chunk
            }
        }
        __syncthreads();

        // write back results
        float carry = *s_carry;
        // calculate sum offset for this thread
        float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];

#pragma unroll
        for (int32_t j = 0; j < num_unroll; j++) {
            if (idx + j < ne00) {
                dst_row[idx + j] = temp[j] + ggml_cuda_cast<T, float>(final_val_offset);
            }
        }

        __syncthreads();

        // Update carry for next chunk
        if (tid == 0) {
            *s_carry += *s_chunk_total;
        }
    }
}

#ifdef GGML_CUDA_USE_CUB
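// Inclusive scan of a single contiguous row via cub::DeviceScan: CUB is first queried for
// the temporary storage it needs, which is then taken from the CUDA memory pool.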
template <typename T>
static void cumsum_cub(ggml_cuda_pool & pool,
                       const T * src,
                       T * dst,
                       int64_t ne,
                       cudaStream_t stream) {
    size_t tmp_size = 0;

    // Query how much temp storage CUDA UnBound (CUB) needs
    cub::DeviceScan::InclusiveSum(nullptr,   // d_temp_storage (null = just query size)
                                  tmp_size,  // reference to size (will be set by CUB)
                                  src,       // input pointer
                                  dst,       // output pointer
                                  ne,        // number of elements
                                  stream     // CUDA stream to use
    );

    ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);

    // Perform the inclusive scan
    cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream);
}
#endif // GGML_CUDA_USE_CUB

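// Dispatch (when built with GGML_CUDA_USE_CUB): contiguous rows that are long relative to
// their count go to cub::DeviceScan, one call per row; other contiguous rows with
// ne00 >= 1024 use the cub::BlockScan kernel; everything else falls back to cumsum_kernel.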
template<typename T>
static void cumsum_cuda(
        [[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
        const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
        cudaStream_t stream) {

    const size_t type_size = sizeof(T);
    bool use_cub = false;
#ifdef GGML_CUDA_USE_CUB
    // Check if we can use CUB (data must be contiguous along innermost dimension)
    const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size);

    if (is_contiguous) {
        use_cub = true;
        const int64_t nrows = ne01 * ne02 * ne03;
        // TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
        // Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
        if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {
            for (int64_t i = 0; i < nrows; i++) {
                cumsum_cub(ctx.pool(), src + i * ne00, dst + i * ne00, ne00, stream);
            }
            return;
        }
    }
#endif // GGML_CUDA_USE_CUB
    dim3 grid_dims(ne01, ne02, ne03);
    const auto & info = ggml_cuda_info().devices[ggml_cuda_get_device()];
    const int warp_size = info.warp_size;
    const int64_t num_warps = (ne00 + warp_size - 1) / warp_size;
    const int block_size = (int) std::min<int64_t>(num_warps * warp_size, CUDA_CUMSUM_BLOCK_SIZE);
    dim3 block_dims(block_size, 1, 1);
    const int warps_per_block = block_size / warp_size;
    const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);

    if (use_cub && ne00 >= 1024) {
        cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
            src, dst,
            ne00, ne01, ne02, ne03,
            nb01 / type_size, nb02 / type_size, nb03 / type_size,
            nb1 / type_size, nb2 / type_size, nb3 / type_size
        );
    } else {
        cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
            src, dst,
            ne00, ne01, ne02, ne03,
            nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
            nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
        );
    }
}

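// Entry point for the CUMSUM op: dispatches on the tensor type (F32 only for now).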
void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == dst->type);
    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                cumsum_cuda(
                    ctx, (const float *) src0->data, (float *) dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;
        // F16/BF16 are not supported on the CPU backend for now anyway, so keep these cases
        // commented out: they cause build errors on some CI platforms.
        /*case GGML_TYPE_F16:
            {
                cumsum_cuda(
                    ctx, (const half *) src0->data, (half *) dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;
        case GGML_TYPE_BF16:
            {
                cumsum_cuda(
                    ctx, (const nv_bfloat16 *) src0->data, (nv_bfloat16 *) dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;*/
        default:
            GGML_ABORT("fatal error");
    }
}