#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
#define USE_CUB
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070

#ifdef USE_CUB
#include <cub/cub.cuh>
#endif // USE_CUB

#include "ssm-scan.cuh"
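
// CUDA kernels for GGML_OP_SSM_SCAN, the selective state-space scan used by
// Mamba-style models: ssm_scan_f32 covers the Mamba-1 layout, while
// ssm_scan_f32_group covers the Mamba-2 / Falcon-H1 grouped layout.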

// We would like to keep pragma unroll for cases where L_template is not 0,
// so we suppress the clang transformation warning.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template <size_t splitD, size_t N, size_t L_template>
__global__ void __launch_bounds__(splitD, 1)
    ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
                 const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
                 const int32_t * __restrict__ src6, float * __restrict__ dst,
                 const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
                 const int src2_nb1, const int src2_nb2, const int src3_nb1,
                 const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
                 const int64_t s_off, const int64_t d_inner, const int64_t L_param)
{
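    // L_template != 0 makes the sequence length a compile-time constant so the
    // token loop below can fully unroll; L_template == 0 selects the runtime-length path.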
    const size_t L = L_template == 0 ? L_param : L_template;
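    // base pointers for this block's (sequence, splitD-wide channel chunk);
    // src6 maps each sequence to the slot of its initial state in src0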
    const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
    const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
    const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
    const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
    const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3));
    const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3));
    float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
    float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);

    const int stride_x = src1_nb2 / sizeof(float);
    const int stride_dt = src2_nb1 / sizeof(float);
    const int stride_B = src4_nb2 / sizeof(float);
    const int stride_C = src5_nb2 / sizeof(float);
    const int stride_y = d_inner;

    float regA[N];
    float regs0[N];

    __shared__ float smemB[N];
    __shared__ float smemC[N];

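    // with CUB, the per-thread rows of A and s0 are loaded coalesced and then
    // transposed into registers; the fallback issues strided per-thread loads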
#ifdef USE_CUB
    using BlockLoad = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;

    union CubTempStorage {
        typename BlockLoad::TempStorage load_temp;
        typename BlockStore::TempStorage store_temp;
    };
    __shared__ CubTempStorage cub_temp_storage;

    BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
    BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
#else
    const int stride_s0 = src0_nb2 / sizeof(float);
    const int stride_A = src3_nb1 / sizeof(float);
#pragma unroll
    for (size_t n = 0; n < N; ++n)
    {
        regA[n] = A_block[threadIdx.x * stride_A + n];
        regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
    }
#endif

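    // scan over the tokens of the sequence: each thread owns one channel and keeps
    // its N states in registers, while B and C for the current token are staged in
    // shared memory:
    //   state[n] = state[n] * exp(dt_sp * A[n]) + B[n] * (x * dt_sp)
    //   y        = sum_n(state[n] * C[n])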
#pragma unroll
    for (size_t i = 0; i < L; i++)
    {
        if (threadIdx.x < N)
        {
            smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
            smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
        }
        __syncthreads();

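        // softplus(dt); for dt > 20.0f, softplus(dt) == dt to within fp32
        // precision, so the transcendentals are skipped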
        float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
        if (dt_soft_plus <= 20.0f)
        {
            dt_soft_plus = log1pf(expf(dt_soft_plus));
        }
        float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;

        float sumf = 0.0f;
#pragma unroll
        for (size_t n = 0; n < N; n++)
        {
            float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
            sumf += state * smemC[n];
            regs0[n] = state;
        }
        y_block[i * stride_y + threadIdx.x] = sumf;

        // make sure every thread is done reading smemB/smemC before the next
        // token's values overwrite them
        __syncthreads();
    }

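    // write the final recurrent state for this sequence back to dst at byte offset s_off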
#ifdef USE_CUB
    BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
#else
    const int stride_s = stride_s0;
#pragma unroll
    for (size_t n = 0; n < N; ++n)
    {
        s_block[threadIdx.x * stride_s + n] = regs0[n];
    }
#endif
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__

// assumes as many threads as d_state
template <int c_factor, int d_state>
__global__ void __launch_bounds__(d_state, 1)
    ssm_scan_f32_group(
        const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
        const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
        const int32_t * __restrict__ src6, float * __restrict__ dst,
        const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
        const int src2_nb1, const int src2_nb2, const int src3_nb1,
        const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
        const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {

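    // one warp per (head, element-of-head) pair; each lane keeps
    // c_factor = d_state / WARP_SIZE states in registers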
    const int warp     = threadIdx.x / WARP_SIZE;
    const int lane     = threadIdx.x % WARP_SIZE;
    const int warp_idx = blockIdx.x  * c_factor + warp;

    const int head_idx =  warp_idx / d_head;
    const int head_off = (warp_idx % d_head) * sizeof(float);
    const int seq_idx  = blockIdx.y;

    const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);

    // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase
    const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
    const float * x_warp  = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));
    const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
    const float * A_warp  = (const float *) ((const char *) src3 + head_idx * src3_nb1);
    const float * B_warp  = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
    const float * C_warp  = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
    float *       y_warp  = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx;
    float *       s_warp  = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);

    // strides across n_seq_tokens
    const int stride_x  = src1_nb2 / sizeof(float);
    const int stride_dt = src2_nb1 / sizeof(float);
    const int stride_B  = src4_nb2 / sizeof(float);
    const int stride_C  = src5_nb2 / sizeof(float);
    const int stride_y  = n_head * d_head;

    float state[c_factor];
    float state_sum = 0.0f;

#pragma unroll
    for (int j = 0; j < c_factor; j++) {
        state[j] = s0_warp[WARP_SIZE * j + lane];
    }

    for (int64_t i = 0; i < n_tok; i++) {
        // NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here.
        // Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead.
        const float dt           = dt_warp[i * stride_dt];
        const float dt_soft_plus = dt <= 20.0f ? log1pf(expf(dt)) : dt;

        state_sum = 0.0f;
        const float dA   = expf(dt_soft_plus * A_warp[0]);
        const float x_dt = x_warp[i * stride_x] * dt_soft_plus;
#pragma unroll
        for (int j = 0; j < c_factor; j++) {
            const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane];
            const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane];
            state[j] = (state[j] * dA) + (B_val * x_dt);
            state_sum += state[j] * C_val;
        }

        // parallel accumulation for output
        state_sum = warp_reduce_sum(state_sum);

        if (lane == 0) {
            y_warp[i * stride_y] = state_sum;
        }
    }

    // write back the state
#pragma unroll
    for (int j = 0; j < c_factor; j++) {
        s_warp[WARP_SIZE * j + lane] = state[j];
    }
}

static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
                              const float * src4, const float * src5, const int32_t * src6, float * dst,
                              const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
                              const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
                              const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                              const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                              cudaStream_t stream) {
    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
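    // for Mamba-2, A is a single scalar per head, so src3's rows are one float wide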
    if (src3_nb1 == sizeof(float)) {
        // Mamba-2
        if (d_state == 128) {
            constexpr int threads   = 128;
            constexpr int num_warps = threads/WARP_SIZE;

            const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
            ssm_scan_f32_group<128/WARP_SIZE, 128><<<blocks, threads, 0, stream>>>(
                    src0, src1, src2, src3, src4, src5, src6, dst,
                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                    src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
        } else if (d_state == 256) { // Falcon-H1
            constexpr int threads   = 256;
            constexpr int num_warps = threads/WARP_SIZE;

            const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
            ssm_scan_f32_group<256/WARP_SIZE, 256><<<blocks, threads, 0, stream>>>(
                    src0, src1, src2, src3, src4, src5, src6, dst,
                    src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
                    src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
        } else {
            GGML_ABORT("ssm_scan: unsupported d_state (expected 128 or 256)");
        }
    } else {
        // Mamba-1
        constexpr int threads = 128;
        GGML_ASSERT(n_head % threads == 0);
        GGML_ASSERT(head_dim == 1);
        GGML_ASSERT(n_group == 1);
        const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
        // the kernel uses only statically-sized shared memory, so no dynamic shared memory is needed
        constexpr int smem_size = 0;
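        // small token counts get a compile-time sequence length so the scan loop
        // fully unrolls; L_template == 0 is the runtime-length fallback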
        if (d_state == 16) {
            switch (n_tok) {
                case 1:
                    ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 2:
                    ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 3:
                    ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 4:
                    ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 5:
                    ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 6:
                    ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 7:
                    ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                case 8:
                    ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
                default:
                    ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>(
                        src0, src1, src2, src3, src4, src5, src6, dst,
                        src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
                        src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
                    break;
            }
        } else {
            GGML_ABORT("ssm_scan: unsupported d_state (expected 16)");
        }
    }
}

void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const struct ggml_tensor * src0 = dst->src[0];  // s
    const struct ggml_tensor * src1 = dst->src[1];  // x
    const struct ggml_tensor * src2 = dst->src[2];  // dt
    const struct ggml_tensor * src3 = dst->src[3];  // A
    const struct ggml_tensor * src4 = dst->src[4];  // B
    const struct ggml_tensor * src5 = dst->src[5];  // C
    const struct ggml_tensor * src6 = dst->src[6];  // ids

    const int64_t nc  = src0->ne[0];  // d_state
    const int64_t nr  = src0->ne[1];  // head_dim or 1
    const int64_t nh  = src1->ne[1];  // n_head
    const int64_t ng  = src4->ne[1];  // n_group
    const int64_t n_t = src1->ne[2];  // number of tokens per sequence
    const int64_t n_s = src1->ne[3];  // number of sequences in the batch

    const int64_t s_off = ggml_nelements(src1) * sizeof(float);
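    // dst holds y (one float per element of x) followed by the final states at byte offset s_off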

    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src2->nb[0] == sizeof(float));
    GGML_ASSERT(src3->nb[0] == sizeof(float));
    GGML_ASSERT(src4->nb[0] == sizeof(float));
    GGML_ASSERT(src5->nb[0] == sizeof(float));
    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));

    const float * src0_d = (const float *) src0->data;
    const float * src1_d = (const float *) src1->data;
    const float * src2_d = (const float *) src2->data;
    const float * src3_d = (const float *) src3->data;
    const float * src4_d = (const float *) src4->data;
    const float * src5_d = (const float *) src5->data;
    const int32_t * src6_d = (const int32_t *) src6->data;
    float *       dst_d  = (float *) dst->data;
    cudaStream_t  stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src6->type == GGML_TYPE_I32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
                      src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
                      src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
                      s_off, nc, nr, nh, ng, n_t, n_s, stream);
}