1#include "common.cuh"
  2#include "ggml.h"
  3#include "softmax.cuh"
  4
  5#ifdef GGML_USE_HIP
  6#include <hip/hip_cooperative_groups.h>
  7#else
  8#include <cooperative_groups.h>
  9#include <cooperative_groups/reduce.h>
 10#endif // GGML_USE_HIP
 11
 12#include <cstdint>
 13#include <utility>
 14
 15template <typename T>
 16static __device__ __forceinline__ float t2f32(T val) {
 17    return (float) val;
 18}
 19
 20template <>
 21__device__ float __forceinline__ t2f32<half>(half val) {
 22    return __half2float(val);
 23}
 24
// Parameter pack for the softmax kernels in this file; filled in by ggml_cuda_op_soft_max.
// Passed to kernels by value (and, for the cooperative launch, through a kernel-argument pointer).
struct soft_max_params {

    int64_t nheads;       // src0->ne[2]; not read by the kernels in this file
    uint32_t n_head_log2; // largest power of two <= number of heads; passed to get_alibi_slope
    int64_t ncols;        // row length (src0->ne[0])
    int64_t nrows_x;      // total number of src0 rows (ggml_nrows); not read by the kernels in this file
    int64_t nrows_y;      // src0->ne[1]; not read by the kernels in this file
    int64_t ne00;         // src0 extents ne00..ne03; ne01*ne02*ne03 is the number of rows processed
    int64_t ne01;
    int64_t ne02;         // also the range of the head index i02 (used for sinks and ALiBi slope)
    int64_t ne03;
    int64_t nb11;         // mask (src1) strides in bytes for dims 1..3 (divided by sizeof(T) in the kernel)
    int64_t nb12;
    int64_t nb13;

    int64_t ne12;         // mask extents in dims 2 and 3; the mask is broadcast via i02 % ne12 and i03 % ne13
    int64_t ne13;
    float scale;          // multiplier applied to x before the (slope-weighted) mask is added
    float max_bias;       // ALiBi max bias; presumably 0 disables ALiBi (see get_alibi_slope) - TODO confirm
    float m0;             // ALiBi slope bases, derived from max_bias and n_head_log2
    float m1;
};
 47
// When ncols_template == 0, the loop bounds in the kernels below are not known at compile time, so those loops can't be unrolled.
// Because we want to keep `#pragma unroll` for all other cases, we suppress clang's loop-transformation warning here.
 50#ifdef __clang__
 51#pragma clang diagnostic push
 52#pragma clang diagnostic ignored "-Wpass-failed"
 53#endif // __clang__
 54template <bool use_shared, int ncols_template, int block_size_template, typename T>
 55static __global__ void soft_max_f32(
 56        const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
 57    const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
 58
 59    const int tid  = threadIdx.x;
 60
 61    const int64_t i03 = blockIdx.z;
 62    const int64_t i02 = blockIdx.y;
 63    const int64_t i01 = blockIdx.x;
 64
 65    //TODO: noncontigous inputs/outputs
 66    const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
 67
 68    const int64_t i11 = i01;
 69    const int64_t i12 = i02 % p.ne12;
 70    const int64_t i13 = i03 % p.ne13;
 71
 72    x    += int64_t(rowx)*ncols;
 73    mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
 74    dst  += int64_t(rowx)*ncols;
 75
 76    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 77
 78    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
 79
 80    extern __shared__ float data_soft_max_f32[];
 81    float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
 82    // shared memory buffer to cache values between iterations:
 83    float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
 84
 85    float max_val = sinks ? sinks[i02] : -INFINITY;
 86
 87#pragma unroll
 88    for (int col0 = 0; col0 < ncols; col0 += block_size) {
 89        const int col = col0 + tid;
 90
 91        if (ncols_template == 0 && col >= ncols) {
 92            break;
 93        }
 94
 95        const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
 96
 97        vals[col] = val;
 98        max_val = max(max_val, val);
 99    }
100
101    // find the max value in the block
102    max_val = block_reduce<block_reduce_method::MAX, block_size_template>(max_val, buf_iw);
103
104    float tmp = 0.0f; // partial sum
105
106#pragma unroll
107    for (int col0 = 0; col0 < ncols; col0 += block_size) {
108        const int col = col0 + tid;
109
110        if (ncols_template == 0 && col >= ncols) {
111            break;
112        }
113
114        const float val = expf(vals[col] - max_val);
115        tmp += val;
116        vals[col] = val;
117    }
118
119    // find the sum of exps in the block
120    tmp = block_reduce<block_reduce_method::SUM, block_size_template>(tmp, buf_iw);
121
122    if (sinks) {
123        tmp += expf(sinks[i02] - max_val);
124    }
125
126    const float inv_sum = 1.0f / tmp;
127
128#pragma unroll
129    for (int col0 = 0; col0 < ncols; col0 += block_size) {
130        const int col = col0 + tid;
131
132        if (ncols_template == 0 && col >= ncols) {
133            return;
134        }
135
136        dst[col] = vals[col] * inv_sum;
137    }
138}
139
// Softmax of one contiguous row of p.ncols floats, computed cooperatively by the whole grid.
//
// Every thread of a cooperatively launched grid must call this, because it synchronizes with cg::this_grid().sync().
// Each block reduces its share of the row and publishes the partial result to global memory. tmp_maxs and tmp_sums
// are indexed by blockIdx.x, so each must hold at least gridDim.x floats. After a grid barrier, every block re-reduces
// those gridDim.x partials with one partial per thread, so gridDim.x must not exceed blockDim.x.
//
// Unlike soft_max_f32, this path ignores p.scale, masks, ALiBi slopes and sinks; the host launcher only selects it
// when scale == 1, max_bias == 0 and there is no mask and no sinks. dst also holds the unnormalized exponentials
// between the second and third pass.
//
// NOTE(review): shared_vals holds 32 floats for block_reduce, which presumably caps blockDim.x at 32*WARP_SIZE.
// TODO: Template to allow keeping ncols in registers if they fit
static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
                                                                float * __restrict__ dst,
                                                                float * __restrict__ tmp_maxs,
                                                                float * __restrict__ tmp_sums,
                                                                const soft_max_params p) {
    namespace cg = cooperative_groups;

    const cg::grid_group g = cg::this_grid();

    const int tid               = threadIdx.x;
    const int col_start         = blockIdx.x * blockDim.x + tid;
    const int n_elem_per_thread = 4;

    float     local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
    float     local_max                     = -INFINITY;
    const int step_size                     = gridDim.x * blockDim.x;
    __shared__ float shared_vals[32];

    // Compute thread-local max; each outer iteration covers n_elem_per_thread grid-wide strides
    for (int col = col_start; col < p.ncols;) {
#pragma unroll
        for (int i = 0; i < n_elem_per_thread; i++) {
            const int idx = col + i * step_size;
            local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
        }
#pragma unroll
        for (int i = 0; i < n_elem_per_thread; i++) {
            local_max = fmaxf(local_max, local_vals[i]);
        }
        col += step_size * n_elem_per_thread;
    }

    // Compute CTA-level max
    local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);

    // Store CTA-level max to GMEM
    if (tid == 0) {
        tmp_maxs[blockIdx.x] = local_max;
    }
    g.sync();

    // Compute the global max from the CTA-level maxima: thread tid reads the partial of block tid,
    // so there must be at least as many threads per block as there are blocks.
    assert(gridDim.x <= blockDim.x);
    if (tid < gridDim.x) {
        local_max = tmp_maxs[tid];
    } else {
        local_max = -INFINITY;
    }
    local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);

    // Compute softmax dividends, accumulate divisor
    float tmp_expf = 0.0f;
    for (int col = col_start; col < p.ncols;) {
#pragma unroll
        for (int i = 0; i < n_elem_per_thread; i++) {
            const int idx = col + i * step_size;
            local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
        }
#pragma unroll
        for (int i = 0; i < n_elem_per_thread; i++) {
            const int idx = col + i * step_size;
            if (idx < p.ncols) {
                const float tmp = expf(local_vals[i] - local_max);
                tmp_expf += tmp;
                dst[idx] = tmp;
            }
        }
        col += step_size * n_elem_per_thread;
    }

    // Reduce divisor within CTA
    tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);

    // Store CTA-level sum to GMEM
    if (tid == 0) {
        tmp_sums[blockIdx.x] = tmp_expf;
    }
    g.sync();

    // Compute global sum from CTA-level sums
    if (tid < gridDim.x) {
        tmp_expf = tmp_sums[tid];
    } else {
        tmp_expf = 0.0f;
    }
    tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);

    // Divide dividend by global sum + store data
    for (int col = col_start; col < p.ncols;) {
#pragma unroll
        for (int i = 0; i < n_elem_per_thread; i++) {
            const int idx = col + i * step_size;
            local_vals[i] = idx < p.ncols ? dst[idx] : -INFINITY;
        }
#pragma unroll
        for (int i = 0; i < n_elem_per_thread; i++) {
            const int idx = col + i * step_size;
            if (idx < p.ncols) {
                dst[idx] = local_vals[i] / tmp_expf;
            }
        }
        col += step_size * n_elem_per_thread;
    }
}
245
246#ifdef __clang__
247#pragma clang diagnostic pop
248#endif // __clang__
249
// Softmax backward pass, one row per block:
//   dst = scale * (grad - dot(grad, dstf)) * dstf
// where dstf is the forward-pass softmax output. The dot product is reduced with warp_reduce_sum only,
// so the kernel expects exactly one warp per block (blockDim.x == WARP_SIZE).
static __global__ void soft_max_back_f32(
        const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
    const int     lane   = threadIdx.x;
    const int64_t offset = int64_t(blockIdx.x)*ncols;

    const float * grad_row = grad + offset;
    const float * dstf_row = dstf + offset;
    float       * dst_row  = dst  + offset;

    // Per-lane partial of dot(dstf, grad), then combined across the warp.
    float dot = 0.0f;
    for (int col = lane; col < ncols; col += WARP_SIZE) {
        dot += dstf_row[col]*grad_row[col];
    }
    dot = warp_reduce_sum(dot);

    for (int col = lane; col < ncols; col += WARP_SIZE) {
        dst_row[col] = scale * (grad_row[col] - dot) * dstf_row[col];
    }
}
271
// Launches soft_max_f32 with use_shared == true.
// If p.ncols equals one of the compile-time column counts Ns, the matching specialization is used
// (row length known at compile time, block size min(ncols, 1024)). Otherwise the runtime-sized instantiation runs.
// block_dims is forwarded unchanged, so the caller's block size must equal min(ncols, 1024) for the specialized cases.
template<int... Ns, typename T>
static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
                             const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
{
    const int    dev_id = ggml_cuda_get_device();
    const size_t smpbo  = ggml_cuda_info().devices[dev_id].smpbo;

    // Launches the specialization for the column count carried by `tag` iff it matches p.ncols.
    auto try_launch_specialized = [=](auto tag) -> bool {
        constexpr int ncols_c = decltype(tag)::value;
        constexpr int block_c = ncols_c > 1024 ? 1024 : ncols_c;

        if (p.ncols != ncols_c) {
            return false;
        }

        CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols_c, block_c, T>), smpbo);
        soft_max_f32<true, ncols_c, block_c><<<block_nums, block_dims, nbytes_shared, stream>>>
            (x, mask, sinks, dst, p);
        return true;
    };

    // Short-circuiting fold: tries each Ns in order and stops at the first match.
    const bool launched = (try_launch_specialized(std::integral_constant<int, Ns>{}) || ...);
    if (launched) {
        return;
    }

    // Fallback: the row length is only known at runtime.
    CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
}
301
// Grid-cooperative softmax over all rows of x, one row at a time.
// Must be launched with cudaLaunchCooperativeKernel (the per-row helper calls cg::this_grid().sync()) and with at most
// 8*WARP_SIZE threads per block (see __launch_bounds__). tmp_maxs and tmp_sums are global scratch with one float per block.
//
// We loop over all rows instead of parallelizing across gridDim.y, because cooperative groups
// currently only support synchronizing the complete grid unless launched as a cluster group
// (which requires compute capability 9.0 or newer):
// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#grid-synchronization
// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group
__launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,
                                                     float * __restrict__ dst,
                                                     float * __restrict__ tmp_maxs,
                                                     float * __restrict__ tmp_sums,
                                                     const soft_max_params p) {
    const int64_t nrows = p.ne01 * p.ne02 * p.ne03;

    for (int row = 0; row < nrows; ++row) {
        const int64_t offset = int64_t(row) * p.ncols;
        soft_max_f32_parallelize_cols_single_row(x + offset, dst + offset, tmp_maxs, tmp_sums, p);
    }
}
318
// Host launcher for the forward softmax. It picks one of three strategies:
//   1. the whole padded row (plus WARP_SIZE scratch floats) fits into shared memory -> soft_max_f32<true, ...>,
//      specialized on common power-of-two row lengths;
//   2. rows are very long relative to their count, the op is a plain softmax (no mask/sinks, scale 1, no ALiBi)
//      and the device supports cooperative launch -> one grid of nsm blocks processes the rows one at a time;
//   3. otherwise -> soft_max_f32<false, 0, 0>, which caches intermediates in dst instead of shared memory.
template <typename T>
static void soft_max_f32_cuda(const float *                                x,
                              const T *                                    mask,
                              const float *                                sinks,
                              float *                                      dst,
                              const soft_max_params &                      params,
                              cudaStream_t                                 stream,
                              [[maybe_unused]] ggml_backend_cuda_context & ctx) {
    int nth = WARP_SIZE;
    const int64_t ncols_x = params.ncols;

    // Smallest power of two >= ncols_x, clamped to [WARP_SIZE, CUDA_SOFT_MAX_BLOCK_SIZE].
    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
    const dim3 block_dims(nth,     1, 1);
    const dim3 block_nums(params.ne01, params.ne02, params.ne03);
    const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

    const int    id    = ggml_cuda_get_device();
    const auto & info  = ggml_cuda_info().devices[id];
    const size_t smpbo = info.smpbo;

    if (nbytes_shared <= smpbo) {
        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
        return;
    }

    // Parallelize across SMs for top-p/dist-sampling.
    // The heuristic for parallelizing rows across SMs vs parallelizing a single row & looping over all rows was tuned on
    // a B6000 GPU and can be adapted further for lower-SM-count GPUs, though keeping data in registers should be
    // implemented first as that is the optimal solution.
    const int     nsm      = info.nsm;
    const int     nth_coop = WARP_SIZE * 8; // must match __launch_bounds__ of soft_max_f32_parallelize_cols
    const int64_t nrows    = params.ne01 * params.ne02 * params.ne03;

    // The cooperative kernel launches one block per SM. Its final reduction reads one per-block partial per thread,
    // so the grid is kept strictly smaller than the block size (this also satisfies its device-side assert).
    const bool use_coop = info.supports_cooperative_launch && nsm < nth_coop &&
                          ncols_x / nrows > 8192 && mask == nullptr && sinks == nullptr &&
                          params.scale == 1.0f && params.max_bias == 0.0f;

    if (use_coop) {
        // ggml_cuda_pool_alloc<T> takes an element count: one partial max/sum per block.
        ggml_cuda_pool_alloc<float> tmp_maxs_alloc(ctx.pool(), nsm);
        ggml_cuda_pool_alloc<float> tmp_sums_alloc(ctx.pool(), nsm);

        void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs_alloc.ptr,
                                 (void *) &tmp_sums_alloc.ptr, (void *) const_cast<soft_max_params *>(&params) };
        CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
                                               dim3(nsm, 1, 1),
                                               dim3(nth_coop, 1, 1), kernel_args, 0, stream));
    } else {
        // Only the inter-warp scratch lives in shared memory; intermediates are cached in dst.
        const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
        soft_max_f32<false, 0, 0>
            <<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
    }
}
365
// Host launcher for soft_max_back_f32: one single-warp block per row, no dynamic shared memory.
static void soft_max_back_f32_cuda(
        const float * grad, const float * dstf, float * dst,
        const int ncols, const int nrows, const float scale, cudaStream_t stream) {
    soft_max_back_f32<<<dim3(nrows, 1, 1), dim3(WARP_SIZE, 1, 1), 0, stream>>>(grad, dstf, dst, ncols, scale);
}
374
// CUDA implementation of the softmax op: dst = softmax(src0*scale + slope*src1) along dim 0.
//   src0 - F32 logits
//   src1 - optional F16/F32 mask
//   src2 - optional F32 per-head sinks
// scale and max_bias come from dst->op_params[0] and [1]. The ALiBi slope bases m0/m1 are derived from max_bias
// and the number of heads (src0->ne[2]).
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // logits
    const ggml_tensor * src1 = dst->src[1]; // mask (optional)
    const ggml_tensor * src2 = dst->src[2]; // sinks (optional)

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional

    float scale    = 1.0f;
    float max_bias = 0.0f;
    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));

    // Largest power of two not exceeding the head count, and the slope bases derived from it.
    const uint32_t n_head      = src0->ne[2];
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    soft_max_params params = {};

    // Shape of the logits.
    params.ncols   = src0->ne[0];
    params.nrows_x = ggml_nrows(src0);
    params.nrows_y = src0->ne[1];
    params.nheads  = src0->ne[2];
    params.ne00    = src0->ne[0];
    params.ne01    = src0->ne[1];
    params.ne02    = src0->ne[2];
    params.ne03    = src0->ne[3];

    // Mask strides/extents. Without a mask these are placeholders: ne12/ne13 = 1 keep the kernel's
    // broadcast modulo well-defined, and the offset is discarded for a null mask.
    params.nb11 = src1 ? src1->nb[1] : 1;
    params.nb12 = src1 ? src1->nb[2] : 1;
    params.nb13 = src1 ? src1->nb[3] : 1;
    params.ne12 = src1 ? src1->ne[2] : 1;
    params.ne13 = src1 ? src1->ne[3] : 1;

    params.n_head_log2 = n_head_log2;
    params.scale       = scale;
    params.max_bias    = max_bias;
    params.m0          = m0;
    params.m1          = m1;

    const float * x_d     = (const float *) src0->data;
    const float * sinks_d = src2 ? (const float *) src2->data : nullptr;
    float       * out_d   = (float *) dst->data;
    cudaStream_t  stream  = ctx.stream();

    // Dispatch on the mask element type (a missing mask goes through the float instantiation).
    if (src1 && src1->type == GGML_TYPE_F16) {
        soft_max_f32_cuda(x_d, (const half *) src1->data, sinks_d, out_d, params, stream, ctx);
    } else {
        const float * mask_d = src1 ? (const float *) src1->data : nullptr;
        soft_max_f32_cuda(x_d, mask_d, sinks_d, out_d, params, stream, ctx);
    }
}
445
// CUDA implementation of the softmax backward op.
// src0 is the incoming gradient and src1 the forward-pass softmax output; all tensors are F32.
// scale and max_bias come from dst->op_params[0] and [1]; only max_bias == 0 is supported.
void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // grad
    const ggml_tensor * src1 = dst->src[1]; // forward pass output

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    float scale    = 1.0f;
    float max_bias = 0.0f;
    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));

    GGML_ASSERT(max_bias == 0.0f);

    const int64_t ncols = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);

    soft_max_back_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data,
                           ncols, nrows, scale, ctx.stream());
}