1#include "ssm_conv.hpp"
2#include "common.hpp"
3
4#include <cstdio>
5
6using namespace sycl;
7
8static void kernel_ssm_conv(
9 queue &q,
10 const float *src_data,
11 const float *weights,
12 float *dst_data,
13 int d_conv,
14 int d_inner,
15 int n_t,
16 int n_s,
17 int ncs __attribute__((unused)),
18 int src_stride_inner,
19 int src_stride_seq,
20 int dst_stride_token,
21 int dst_stride_seq
22) {
23 const size_t total_work = static_cast<size_t>(d_inner) * static_cast<size_t>(n_t) * static_cast<size_t>(n_s);
24 const size_t work_group_size = 256;
25 const size_t num_work_groups = (total_work + work_group_size - 1) / work_group_size;
26
27 const range<1> global_range(num_work_groups * work_group_size);
28 const range<1> local_range(work_group_size);
29
30 q.submit([&](handler &h) {
31 h.parallel_for(
32 nd_range<1>(global_range, local_range),
33 [=](nd_item<1> item) {
34 const size_t idx = item.get_global_id(0);
35 if (idx >= total_work) {
36 return;
37 }
38
39 const int channel = static_cast<int>(idx % d_inner);
40 const int token = static_cast<int>((idx / d_inner) % n_t);
41 const int seq = static_cast<int>(idx / (static_cast<size_t>(d_inner) * static_cast<size_t>(n_t)));
42
43 const float *s = src_data
44 + static_cast<size_t>(seq) * static_cast<size_t>(src_stride_seq)
45 + static_cast<size_t>(channel) * static_cast<size_t>(src_stride_inner)
46 + static_cast<size_t>(token);
47
48 const float *c = weights + static_cast<size_t>(channel) * static_cast<size_t>(d_conv);
49
50 float sumf = 0.0f;
51 for (int i0 = 0; i0 < d_conv; ++i0) {
52 sumf += s[i0] * c[i0];
53 }
54
55 const size_t dst_idx =
56 static_cast<size_t>(seq) * static_cast<size_t>(dst_stride_seq) +
57 static_cast<size_t>(token) * static_cast<size_t>(dst_stride_token) +
58 static_cast<size_t>(channel);
59
60 dst_data[dst_idx] = sumf;
61 }
62 );
63 });
64}
65
66void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
67 ggml_tensor * src0 = dst->src[0];
68 ggml_tensor * src1 = dst->src[1];
69
70 GGML_ASSERT(src0->type == GGML_TYPE_F32);
71 GGML_ASSERT(src1->type == GGML_TYPE_F32);
72 GGML_ASSERT(dst->type == GGML_TYPE_F32);
73
74 const int d_conv = src1->ne[0];
75 const int ncs = src0->ne[0];
76 const int d_inner = src0->ne[1];
77 const int n_t = dst->ne[1];
78 const int n_s = dst->ne[2];
79
80 GGML_ASSERT(src0->ne[0] == d_conv - 1 + n_t);
81 GGML_ASSERT(src0->ne[1] == d_inner);
82 GGML_ASSERT(src1->ne[1] == d_inner);
83
84 GGML_ASSERT(dst->ne[0] == d_inner);
85 GGML_ASSERT(dst->ne[1] == n_t);
86 GGML_ASSERT(dst->ne[2] == n_s);
87
88 GGML_ASSERT(src0->nb[0] == sizeof(float));
89 GGML_ASSERT(src1->nb[0] == sizeof(float));
90
91 GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
92
93 const int src_stride_inner = ncs;
94 const int src_stride_seq = ncs * d_inner;
95 const int dst_stride_token = d_inner;
96 const int dst_stride_seq = d_inner * n_t;
97
98 try {
99 queue *q = ctx.stream();
100
101 const float *src_data = static_cast<const float *>(src0->data);
102 const float *weights = static_cast<const float *>(src1->data);
103 float *dst_data = static_cast<float *>(dst->data);
104
105 GGML_ASSERT(src_data && weights && dst_data);
106
107 kernel_ssm_conv(
108 *q,
109 src_data,
110 weights,
111 dst_data,
112 d_conv,
113 d_inner,
114 n_t,
115 n_s,
116 ncs,
117 src_stride_inner,
118 src_stride_seq,
119 dst_stride_token,
120 dst_stride_seq
121 );
122
123 } catch (const std::exception &e) {
124 std::fprintf(stderr, "[SYCL-SSM_CONV] ERROR: %s\n", e.what());
125 throw;
126 }
127}