1#include "ssm-conv.cuh"
2
3template <size_t split_d_inner, size_t d_conv>
4static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
5 const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
6 float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
7 const int64_t n_t) {
8 GGML_UNUSED(src0_nb0);
9 const int tid = threadIdx.x;
10 const int bidx = blockIdx.x;
11 const int bidy = blockIdx.y;
12
13 const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
14 const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
15 float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);
16
17 const int stride_x = src0_nb1 / sizeof(float);
18 const int stride_w = src1_nb1 / sizeof(float);
19 const int stride_y = dst_nb1 / sizeof(float);
20
21 float x[d_conv] = { 0.0f };
22 float w[d_conv] = { 0.0f };
23
24#pragma unroll
25 for (size_t j = 0; j < d_conv; j++) {
26 w[j] = w_block[tid * stride_w + j];
27 }
28
29 for (int64_t i = 0; i < n_t; i++) {
30 float sumf = 0.0f;
31
32 if (i == 0) {
33 for (size_t j = 0; j < d_conv; j++) {
34 x[j] = x_block[tid * stride_x + j];
35 }
36 } else {
37 x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
38 }
39
40#pragma unroll
41 for (size_t j = 0; j < d_conv; j++) {
42 sumf += x[(i + j) % d_conv] * w[j];
43 }
44 y_block[i * stride_y + tid] = sumf;
45 }
46}
47
48template <size_t split_d_inner, size_t d_conv, int64_t split_n_t>
49static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
50 const int src0_nb0, const int src0_nb1, const int src0_nb2,
51 const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
52 const int dst_nb1, const int dst_nb2, const int64_t n_t) {
53 const int tid = threadIdx.x;
54 const int bidx = blockIdx.x;
55 const int bidy = blockIdx.y;
56 const int bidz = blockIdx.z;
57
58 const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
59 bidz * split_n_t * src0_nb0);
60 const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
61 float * y_block =
62 (float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);
63
64 const int stride_x = src0_nb1 / sizeof(float);
65 const int stride_w = src1_nb1 / sizeof(float);
66 const int stride_y = dst_nb1 / sizeof(float);
67
68 float x[d_conv] = { 0.0f };
69 float w[d_conv] = { 0.0f };
70
71#pragma unroll
72 for (size_t j = 0; j < d_conv; j++) {
73 w[j] = w_block[tid * stride_w + j];
74 }
75
76#pragma unroll
77 for (int64_t i = 0; i < split_n_t; i++) {
78 if (bidz * split_n_t + i < n_t) {
79 float sumf = 0.0f;
80
81 if (i == 0) {
82 for (size_t j = 0; j < d_conv; j++) {
83 x[j] = x_block[tid * stride_x + j];
84 }
85 } else {
86 x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
87 }
88
89#pragma unroll
90 for (size_t j = 0; j < d_conv; j++) {
91 sumf += x[(i + j) % d_conv] * w[j];
92 }
93 y_block[i * stride_y + tid] = sumf;
94 }
95 }
96}
97
98static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
99 const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
100 const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
101 const int64_t n_s, cudaStream_t stream) {
102 const int threads = 128;
103 GGML_ASSERT(nr % threads == 0);
104
105 auto launch_kernel = [&](auto NC) {
106 constexpr int kNC = decltype(NC)::value;
107 if (n_t <= 32) {
108 const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
109 ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
110 dst, dst_nb0, dst_nb1, dst_nb2, n_t);
111 } else {
112 const int64_t split_n_t = 32;
113 dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
114 ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>(
115 src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
116 }
117 };
118
119 switch (nc) {
120 case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
121 case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
122 case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
123 default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now.");
124 }
125}
126
127void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
128 const struct ggml_tensor * src0 = dst->src[0]; // conv_x
129 const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
130
131 const int64_t nc = src1->ne[0]; // d_conv
132 const int64_t nr = src0->ne[1]; // d_inner
133 const int64_t n_t = dst->ne[1]; // tokens per sequence
134 const int64_t n_s = dst->ne[2]; // number of sequences in the batch
135
136 GGML_ASSERT(dst->ne[0] == nr);
137 GGML_ASSERT(src0->nb[0] == sizeof(float));
138 GGML_ASSERT(src1->nb[0] == sizeof(float));
139 GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
140
141 const float * src0_d = (const float *) src0->data;
142 const float * src1_d = (const float *) src1->data;
143 float * dst_d = (float *) dst->data;
144 cudaStream_t stream = ctx.stream();
145
146 GGML_ASSERT(src0->type == GGML_TYPE_F32);
147 GGML_ASSERT(dst->type == GGML_TYPE_F32);
148 ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
149 dst->nb[2], nc, nr, n_t, n_s, stream);
150}