1#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
2#define USE_CUB
3#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
4
5#ifdef USE_CUB
6#include <cub/cub.cuh>
7using namespace cub;
8#endif // USE_CUB
9
10#include "ssm-scan.cuh"
11
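// Keep `#pragma unroll` on the per-token loop so it is fully unrolled when L_template != 0.
// When L_template == 0 the trip count is only known at runtime, the unroll request cannot be
// honored, and clang reports a -Wpass-failed transformation warning; suppress that warning here.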
14#ifdef __clang__
15#pragma clang diagnostic push
16#pragma clang diagnostic ignored "-Wpass-failed"
17#endif // __clang__
// Mamba-1 selective scan.
//
// Launch layout (see ssm_scan_f32_cuda): grid = (n_seq, d_inner / splitD), block = splitD threads,
// no dynamic shared memory. Each thread handles one channel of d_inner: it keeps that channel's
// N-element state (regs0) and its row of A (regA) in registers for all L tokens. For every token
// the first N threads stage B and C in shared memory, and the whole block then reads them.
//
// Template parameters:
//   splitD     - threads per block == channels per block; must be >= N.
//   N          - state size (d_state).
//   L_template - compile-time token count, or 0 to use the runtime L_param.
//
// Pointers/strides: src0 = s, src1 = x, src2 = dt, src3 = A, src4 = B, src5 = C,
// src6 = per-sequence index selecting the input state in src0; all *_nb* strides and s_off are
// in bytes. Output: y is written at the start of dst (d_inner * L floats per sequence) and the
// final state at byte offset s_off.
//
// NOTE(review): with USE_CUB, A and s0 are loaded with cub::BlockLoad in blocked arrangement,
// which assumes each channel's N values are contiguous and channels are densely packed; the
// non-CUB path honors src3_nb1 / src0_nb2 instead.
template <size_t splitD, size_t N, size_t L_template>
__global__ void __launch_bounds__(splitD, 1)
    ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
                 const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
                 const int32_t * __restrict__ src6, float * __restrict__ dst,
                 const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
                 const int src2_nb1, const int src2_nb2, const int src3_nb1,
                 const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
                 const int64_t s_off, const int64_t d_inner, const int64_t L_param)
{
    // Threads [0, N) stage B and C into shared memory, so the block needs at least N threads.
    static_assert(splitD >= N, "ssm_scan_f32: splitD must be >= N");

    const size_t L = L_template == 0 ? L_param : L_template;
    const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
    const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
    const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
    const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
    const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3));
    const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3));
    float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
    float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);

    // Per-token strides, converted from bytes to floats.
    const int stride_x = src1_nb2 / sizeof(float);
    const int stride_dt = src2_nb1 / sizeof(float);
    const int stride_B = src4_nb2 / sizeof(float);
    const int stride_C = src5_nb2 / sizeof(float);
    const int stride_y = d_inner;

    float regA[N];
    float regs0[N];

    __shared__ float smemB[N];
    __shared__ float smemC[N];

#ifdef USE_CUB
    using BlockLoad = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;

    union CubTempStorage {
        typename BlockLoad::TempStorage load_temp;
        typename BlockStore::TempStorage store_temp;
    };
    __shared__ CubTempStorage cub_temp_storage;

    BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
    // CUB requires a barrier before a collective's temporary storage is reused.
    __syncthreads();
    BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
#else
    const int stride_s0 = src0_nb2 / sizeof(float);
    const int stride_A = src3_nb1 / sizeof(float);
#pragma unroll
    for (size_t n = 0; n < N; ++n)
    {
        regA[n] = A_block[threadIdx.x * stride_A + n];
        regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
    }
#endif

#pragma unroll
    for (size_t i = 0; i < L; i++)
    {
        // Stage this token's B and C for the whole block.
        if (threadIdx.x < N)
        {
            smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
            smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
        }
        __syncthreads();

        // softplus(dt); for dt > 20 log1p(exp(dt)) == dt to float precision, so skip it.
        float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
        if (dt_soft_plus <= 20.0f)
        {
            dt_soft_plus = log1pf(expf(dt_soft_plus));
        }
        float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;

        float sumf = 0.0f;
#pragma unroll
        for (size_t n = 0; n < N; n++)
        {
            float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
            sumf += state * smemC[n];
            regs0[n] = state;
        }
        y_block[i * stride_y + threadIdx.x] = sumf;

        // Every thread must be done reading smemB/smemC before threads [0, N) overwrite them
        // for the next token (write-after-read hazard across warps).
        __syncthreads();
    }

#ifdef USE_CUB
    // store_temp aliases load_temp; ensure all threads are done with the load storage
    // (the token loop executes no barrier when L == 0).
    __syncthreads();
    BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
#else
    const int stride_s = stride_s0;
#pragma unroll
    for (size_t n = 0; n < N; ++n)
    {
        s_block[threadIdx.x * stride_s + n] = regs0[n];
    }
#endif
}
112#ifdef __clang__
113#pragma clang diagnostic pop
114#endif // __clang__
115
// Mamba-2 selective scan: A holds one scalar per head, and B/C are shared by the heads of a group.
//
// Assumes as many threads per block as d_state: the block is c_factor warps, and each lane keeps
// c_factor state elements (indices WARP_SIZE * j + lane) in registers.
// Launch layout (see ssm_scan_f32_cuda): grid = (ceil(n_head * d_head / c_factor), n_seq).
// Each warp owns one row of the state, identified by warp_idx = blockIdx.x * c_factor + warp
// (head = warp_idx / d_head, channel within the head = warp_idx % d_head). For every token the
// per-lane partial sums are combined with warp_reduce_sum and lane 0 writes y.
// No shared memory and no block-level barriers are used. All *_nb* strides and s_off are in bytes.
template <int c_factor, int d_state>
__global__ void __launch_bounds__(d_state, 1)
    ssm_scan_f32_group(
        const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
        const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
        const int32_t * __restrict__ src6, float * __restrict__ dst,
        const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
        const int src2_nb1, const int src2_nb2, const int src3_nb1,
        const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
        const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
    // One lane per state element, spread over c_factor warps.
    static_assert(c_factor * WARP_SIZE == d_state, "ssm_scan_f32_group: d_state must equal c_factor * WARP_SIZE");

    const int warp = threadIdx.x / WARP_SIZE;
    const int lane = threadIdx.x % WARP_SIZE;
    const int warp_idx = blockIdx.x * c_factor + warp;

    // The grid is rounded up to whole blocks, so trailing warps of the last block may have no row.
    // warp_idx is uniform within a warp, so the whole warp exits together and warp_reduce_sum
    // below still sees every lane; there is no __syncthreads() in this kernel to deadlock on.
    if (warp_idx >= n_head * d_head) {
        return;
    }

    const int head_idx = warp_idx / d_head;
    const int head_off = (warp_idx % d_head) * sizeof(float);
    const int seq_idx = blockIdx.y;

    // Byte offset of this head's group within B and C.
    const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);

    // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase
    const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
    const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));
    const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
    const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1);
    const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
    const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
    float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx;
    float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);

    // strides across n_seq_tokens
    const int stride_x = src1_nb2 / sizeof(float);
    const int stride_dt = src2_nb1 / sizeof(float);
    const int stride_B = src4_nb2 / sizeof(float);
    const int stride_C = src5_nb2 / sizeof(float);
    const int stride_y = n_head * d_head;

    float state[c_factor];
    float state_sum = 0.0f;

#pragma unroll
    for (int j = 0; j < c_factor; j++) {
        state[j] = s0_warp[WARP_SIZE * j + lane];
    }

    for (int64_t i = 0; i < n_tok; i++) {
        // NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here.
        // Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead.
        const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]);

        state_sum = 0.0f;
        const float dA = expf(dt_soft_plus * A_warp[0]);
        const float x_dt = x_warp[i * stride_x] * dt_soft_plus;
#pragma unroll
        for (int j = 0; j < c_factor; j++) {
            const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane];
            const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane];
            state[j] = (state[j] * dA) + (B_val * x_dt);
            state_sum += state[j] * C_val;
        }

        // parallel accumulation for output
        state_sum = warp_reduce_sum(state_sum);

        if (lane == 0) {
            y_warp[i * stride_y] = state_sum;
        }
    }

    // write back the state
#pragma unroll
    for (int j = 0; j < c_factor; j++) {
        s_warp[WARP_SIZE * j + lane] = state[j];
    }
}
193
// Host launcher for the SSM scan kernels.
//
// Dispatch (must stay in sync with the backend's supports_op check):
//   - src3_nb1 == sizeof(float): A holds one value per head (Mamba-2 style). Runs
//     ssm_scan_f32_group with one thread per state element; d_state must be 128 or 256.
//   - otherwise (Mamba-1): runs ssm_scan_f32 with 128 threads per block; requires d_state == 16,
//     head_dim == 1, n_group == 1 and n_head divisible by 128. n_tok in [1, 8] selects a
//     kernel specialized on the token count; other values use the runtime-length variant.
// All *_nb* strides and s_off are in bytes. Aborts on unsupported shapes.
static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
                              const float * src4, const float * src5, const int32_t * src6, float * dst,
                              const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
                              const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
                              const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
                              const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                              cudaStream_t stream) {
    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
    if (src3_nb1 == sizeof(float)) {
        // Mamba-2: one thread per state element, one warp per (head, head-channel) row.
        // All instantiations share the same signature, so a single pointer type covers them.
        using group_kernel_t = decltype(&ssm_scan_f32_group<128/WARP_SIZE, 128>);
        group_kernel_t kernel  = nullptr;
        int            threads = 0;
        if (d_state == 128) {
            threads = 128;
            kernel  = ssm_scan_f32_group<128/WARP_SIZE, 128>;
        } else if (d_state == 256) { // Falcon-H1
            threads = 256;
            kernel  = ssm_scan_f32_group<256/WARP_SIZE, 256>;
        } else {
            GGML_ABORT("doesn't support d_state!=(128 or 256).");
        }
        const int num_warps = threads / WARP_SIZE;

        const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
        kernel<<<blocks, threads, 0, stream>>>(
            src0, src1, src2, src3, src4, src5, src6, dst,
            src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
            src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
    } else {
        // Mamba-1
        constexpr int threads = 128;
        GGML_ASSERT(n_head % threads == 0);
        GGML_ASSERT(head_dim == 1);
        GGML_ASSERT(n_group == 1);
        const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
        if (d_state != 16) {
            GGML_ABORT("doesn't support d_state!=16.");
        }

        // Specialize on the token count so the per-token loop is fully unrolled for short
        // sequences; L_template == 0 falls back to the runtime n_tok.
        using scan_kernel_t = decltype(&ssm_scan_f32<threads, 16, 0>);
        scan_kernel_t kernel;
        switch (n_tok) {
            case 1:  kernel = ssm_scan_f32<threads, 16, 1>; break;
            case 2:  kernel = ssm_scan_f32<threads, 16, 2>; break;
            case 3:  kernel = ssm_scan_f32<threads, 16, 3>; break;
            case 4:  kernel = ssm_scan_f32<threads, 16, 4>; break;
            case 5:  kernel = ssm_scan_f32<threads, 16, 5>; break;
            case 6:  kernel = ssm_scan_f32<threads, 16, 6>; break;
            case 7:  kernel = ssm_scan_f32<threads, 16, 7>; break;
            case 8:  kernel = ssm_scan_f32<threads, 16, 8>; break;
            default: kernel = ssm_scan_f32<threads, 16, 0>; break;
        }

        // ssm_scan_f32 only uses statically sized shared memory, so no dynamic shared memory
        // is requested (reserving unused dynamic smem would only reduce occupancy).
        kernel<<<blocks, threads, 0, stream>>>(
            src0, src1, src2, src3, src4, src5, src6, dst,
            src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
            src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
    }
}
296
// CUDA implementation of the SSM scan op.
//
// Operands: s (src0, input states), x (src1), dt (src2), A (src3), B (src4), C (src5) and
// ids (src6, int32 index per sequence selecting its input state). dst receives y followed by
// the updated states; s_off is the byte offset of the states inside dst.
// Validates the layout with GGML_ASSERT and forwards byte strides to ssm_scan_f32_cuda.
void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // s
    const ggml_tensor * src1 = dst->src[1]; // x
    const ggml_tensor * src2 = dst->src[2]; // dt
    const ggml_tensor * src3 = dst->src[3]; // A
    const ggml_tensor * src4 = dst->src[4]; // B
    const ggml_tensor * src5 = dst->src[5]; // C
    const ggml_tensor * src6 = dst->src[6]; // ids

    // Problem sizes.
    const int64_t nc  = src0->ne[0]; // d_state
    const int64_t nr  = src0->ne[1]; // head_dim or 1
    const int64_t nh  = src1->ne[1]; // n_head
    const int64_t ng  = src4->ne[1]; // n_group
    const int64_t n_t = src1->ne[2]; // number of tokens per sequence
    const int64_t n_s = src1->ne[3]; // number of sequences in the batch

    // y occupies as many floats as x, and the final states follow it in dst.
    const int64_t s_off = ggml_nelements(src1) * sizeof(float);

    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
    // Every operand must be contiguous along its innermost dimension.
    GGML_ASSERT(src0->nb[0] == sizeof(float));
    GGML_ASSERT(src1->nb[0] == sizeof(float));
    GGML_ASSERT(src2->nb[0] == sizeof(float));
    GGML_ASSERT(src3->nb[0] == sizeof(float));
    GGML_ASSERT(src4->nb[0] == sizeof(float));
    GGML_ASSERT(src5->nb[0] == sizeof(float));
    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src6->type == GGML_TYPE_I32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    ssm_scan_f32_cuda(
        (const float *) src0->data, (const float *) src1->data, (const float *) src2->data,
        (const float *) src3->data, (const float *) src4->data, (const float *) src5->data,
        (const int32_t *) src6->data, (float *) dst->data,
        src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
        src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
        s_off, nc, nr, nh, ng, n_t, n_s, stream);
}