#include <algorithm>
#include "cumsum.cuh"
#include "convert.cuh"
#include "ggml-cuda/common.cuh"
#include "ggml.h"

#ifdef GGML_CUDA_USE_CUB
#   include <cub/cub.cuh>
#endif // GGML_CUDA_USE_CUB

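// Cumulative sum along the innermost dimension using cub::BlockScan.
// Each block handles one row: the row is processed in tiles of
// BLOCK_SIZE * UNROLL_FACTOR elements, with a shared-memory carry
// propagating the running total from one tile to the next.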
template<typename T, int BLOCK_SIZE>
static __global__ void cumsum_cub_kernel(
        const T * __restrict__ src,
        T * __restrict__ dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t  s01, const int64_t  s02, const int64_t  s03,
        const int64_t   s1,  const int64_t   s2,  const int64_t   s3) {
#ifdef GGML_CUDA_USE_CUB
    using BlockScanT = cub::BlockScan<T, BLOCK_SIZE>;

    __shared__ typename BlockScanT::TempStorage temp_storage;
    __shared__ T block_carry;

    const int tid = threadIdx.x;
    constexpr int UNROLL_FACTOR = 4;
    constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR;

    const int64_t i1 = blockIdx.x;
    const int64_t i2 = blockIdx.y;
    const int64_t i3 = blockIdx.z;

    if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
        return;
    }

    const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
    T *       dst_row = dst + i1 * s1  + i2 * s2  + i3 * s3;

    if (tid == 0) {
        block_carry = 0;
    }
    __syncthreads();
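    // Process the row tile by tile: each thread scans UNROLL_FACTOR consecutive
    // elements locally, BlockScan turns the per-thread sums into per-thread
    // offsets, and block_carry accumulates the running total across tiles.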
    for (int64_t start = 0; start < ne00; start += TILE_SIZE) {
        T items[UNROLL_FACTOR];
        T thread_sum = T(0);

#pragma unroll
        for (int i = 0; i < UNROLL_FACTOR; i++) {
            int64_t idx = start + tid * UNROLL_FACTOR + i;
            T val = (idx < ne00) ? src_row[idx] : T(0);
            thread_sum += val;
            items[i] = thread_sum;
        }

        // Block-wide scan on thread sums
        T thread_prefix;
        T block_total;
        BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total);
        __syncthreads();

        // Add offset to each item and store
        T thread_offset = thread_prefix - thread_sum + block_carry;
#pragma unroll
        for (int i = 0; i < UNROLL_FACTOR; i++) {
            int64_t idx = start + tid * UNROLL_FACTOR + i;
            if (idx < ne00) {
                dst_row[idx] = items[i] + thread_offset;
            }
        }

        __syncthreads();

        // Update carry for next tile
        if (tid == 0) {
            block_carry += block_total;
        }
    }
#else
    GGML_UNUSED_VARS(src, dst, ne00, ne01, ne02, ne03, s01, s02, s03, s1, s2, s3);
    NO_DEVICE_CODE;
#endif // GGML_CUDA_USE_CUB
}

// Fallback kernel implementation: hierarchical scan (thread-local -> warp -> block), used when the CUB path is not selected
template<typename T>
static __global__ void cumsum_kernel(
        const T * src, T * dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t  s00, const int64_t  s01, const int64_t  s02, const int64_t  s03,
        const int64_t   s0, const int64_t   s1, const int64_t   s2, const int64_t   s3) {

    GGML_UNUSED_VARS(s00, s0);

    const int tid = threadIdx.x;
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    const int lane = tid % warp_size;
    const int warp = tid / warp_size;
    const int warps_per_block = blockDim.x / warp_size;
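    // Dynamic shared memory layout: blockDim.x scanned per-thread values,
    // warps_per_block per-warp sums, then the running carry and the chunk total.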
    extern __shared__ float smem[];
    float *                 s_vals        = smem;
    float *                 s_warp_sums   = smem + blockDim.x;
    float *                 s_carry       = smem + blockDim.x + warps_per_block;
    float *                 s_chunk_total = s_carry + 1;

    // Initialize carry
    if (tid == 0) {
        *s_carry = 0.0f;
    }
    __syncthreads();

    const int64_t i3 = blockIdx.z;
    const int64_t i2 = blockIdx.y;
    const int64_t i1 = blockIdx.x;
    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
        return;
    }

    const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
    T       * dst_row = dst + i1 * s1  + i2 * s2  + i3 * s3;

    // register blocking: process 4 elements per thread to hide latency
    // and reduce synchronization overhead
    constexpr int num_unroll = 4;
    T             temp[num_unroll];

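    // Each iteration scans one chunk of num_unroll * blockDim.x elements:
    // thread-local scan, warp-level inclusive scan of the thread sums,
    // exclusive scan of the warp sums, then add the carry from previous chunks.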
    for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) {
        int64_t idx = i + tid * num_unroll;

        // thread-local sequential scan
        temp[0] = (idx < ne00 ? src_row[idx] : T(0));
#pragma unroll
        for (int64_t j = 1; j < num_unroll; j++) {
            temp[j] = temp[j - 1];
            if (idx + j < ne00) {
                temp[j] += src_row[idx + j];
            }
        }

        // the last element is the sum of all values assigned to this thread
        float val = (idx < ne00) ? ggml_cuda_cast<float, T>(temp[num_unroll - 1]) : 0.0f;

        // Warp inclusive scan
        val = warp_prefix_inclusive_sum<T, warp_size>(val);
        s_vals[tid] = val;

        if (lane == warp_size - 1) {
            s_warp_sums[warp] = val;
        }
        __syncthreads();

        // Exclusive scan of warp sums (warp 0 only)
        if (warp == 0) {
            float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
            float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
            if (tid < warps_per_block) {
                s_warp_sums[tid] = inc - w;   // exclusive sum
            }
            if (tid == warps_per_block - 1) {
                *s_chunk_total = inc;          // total sum of this chunk
            }
        }
        __syncthreads();

        // write back results
        float carry = *s_carry;
        // offset for this thread: warp-level prefix + warp offset + carry, minus this thread's own sum
        float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];

#pragma unroll
        for (int32_t j = 0; j < num_unroll; j++) {
            if (idx + j < ne00) {
                dst_row[idx + j] = temp[j] + ggml_cuda_cast<T, float>(final_val_offset);
            }
        }

        __syncthreads();

        // Update carry for next chunk
        if (tid == 0) {
            *s_carry += *s_chunk_total;
        }
    }
}

#ifdef GGML_CUDA_USE_CUB
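// Device-wide inclusive scan over one contiguous row using cub::DeviceScan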
template <typename T>
static void cumsum_cub(ggml_cuda_pool & pool,
                       const T *        src,
                       T *              dst,
                       int64_t          ne,
                       cudaStream_t     stream) {
    size_t tmp_size = 0;

    // Query how much temp storage CUDA UnBound (CUB) needs
    cub::DeviceScan::InclusiveSum(nullptr,   // d_temp_storage (null = just query size)
                                  tmp_size,  // reference to size (will be set by CUB)
                                  src,       // input pointer
                                  dst,       // output pointer
                                  ne,        // number of elements
                                  stream     // CUDA stream to use
    );

    ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);

    // Perform the inclusive scan
    cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream);
}
#endif // GGML_CUDA_USE_CUB

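// Dispatch: long contiguous rows go through CUB's device-wide scan (one call per row),
// shorter contiguous rows use the block-scan kernel, and everything else (or builds
// without CUB) falls back to the custom hierarchical scan kernel.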
template<typename T>
static void cumsum_cuda(
        [[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
        const int64_t  nb0, const int64_t  nb1, const int64_t  nb2, const int64_t  nb3,
        cudaStream_t stream) {

    const size_t type_size = sizeof(T);
    bool use_cub = false;
#ifdef GGML_CUDA_USE_CUB
    // Check if we can use CUB (data must be contiguous along the innermost dimension)
    const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size);

    if (is_contiguous) {
        use_cub = true;
        const int64_t nrows = ne01 * ne02 * ne03;
        // TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
        // Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
        if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {
            for (int i = 0; i < nrows; i++) {
                cumsum_cub(ctx.pool(), src + i * ne00, dst + i * ne00, ne00, stream);
            }
            return;
        }
    }
#endif // GGML_CUDA_USE_CUB
    dim3 grid_dims(ne01, ne02, ne03);
    const auto & info = ggml_cuda_info().devices[ggml_cuda_get_device()];
    const int warp_size = info.warp_size;
    const int num_warps = (ne00 + warp_size - 1) / warp_size;
    int block_size = num_warps * warp_size;
    block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
    dim3 block_dims(block_size, 1, 1);
    const int warps_per_block = block_size / warp_size;
    const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);

    if (use_cub && ne00 >= 1024) {
        cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
            src, dst,
            ne00, ne01, ne02, ne03,
            nb01 / type_size, nb02 / type_size, nb03 / type_size,
            nb1 / type_size,  nb2 / type_size,  nb3 / type_size
        );
    } else {
        cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
            src, dst,
            ne00, ne01, ne02, ne03,
            nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
            nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
        );
    }
}

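// Compute the cumulative sum of src0 along its innermost dimension and store it in dst.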
void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == dst->type);
    switch (src0->type) {
        case GGML_TYPE_F32:
            {
                cumsum_cuda(
                    ctx, (const float *)src0->data, (float *)dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;
        // These types are not supported on the CPU for now anyway, so the paths below are
        // commented out because they cause errors on some CI platforms
        /*case GGML_TYPE_F16:
            {
                cumsum_cuda(
                    (const half *)src0->data, (half *)dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;
        case GGML_TYPE_BF16:
            {
                cumsum_cuda(
                    (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;*/
        default:
            GGML_ABORT("fatal error");
    }
}