1#include <sycl/sycl.hpp>
  2#include "wkv.hpp"
  3
// Base head size (C / H) handled by the kernels in this file. The host wrappers
// accept a head size of either WKV_BLOCK_SIZE or 2 * WKV_BLOCK_SIZE and launch
// one work-item per element of a head.
constexpr int WKV_BLOCK_SIZE = 64;
  5
  6// Helper function for the main kernel
  7template <int block_size>
  8static void rwkv_wkv6_f32_kernel(
  9    const int B, const int T, const int C, const int H,
 10    const float* k, const float* v, const float* r,
 11    const float* tf, const float* td, const float* s,
 12    float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
 13
 14    const int tid = item_ct1.get_local_id(2);
 15    const int bid = item_ct1.get_group(2);
 16
 17    const int head_size = block_size;
 18    const int batch_i = bid / H;
 19    const int head_i = bid % H;
 20    const int state_size = C * head_size;
 21    const int n_seq_tokens = T / B;
 22
 23    // Set up shared memory pointers
 24    float* _k = shared_mem;
 25    float* _r = _k + head_size;
 26    float* _tf = _r + head_size;
 27    float* _td = _tf + head_size;
 28
 29    // Local state array
 30    float state[block_size];
 31
 32    // Load initial state
 33    #pragma unroll
 34    for (int i = 0; i < head_size; i++) {
 35        state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
 36    }
 37
 38    // Sync threads before shared memory operations
 39    item_ct1.barrier(sycl::access::fence_space::local_space);
 40
 41    // Load time-mixing parameters
 42    _tf[tid] = tf[head_i * head_size + tid];
 43    item_ct1.barrier(sycl::access::fence_space::local_space);
 44
 45    // Main sequence processing loop
 46    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
 47         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
 48         t += C) {
 49
 50        item_ct1.barrier(sycl::access::fence_space::local_space);
 51
 52        // Load current timestep data to shared memory
 53        _k[tid] = k[t];
 54        _r[tid] = r[t];
 55        _td[tid] = td[t];
 56
 57        item_ct1.barrier(sycl::access::fence_space::local_space);
 58
 59        const float _v = v[t];
 60        float y = 0;
 61
 62        // Process in chunks of 4 for better vectorization
 63        sycl::float4 k4, r4, tf4, td4, s4;
 64        #pragma unroll
 65        for (int j = 0; j < head_size; j += 4) {
 66            // Load data in vec4 chunks
 67            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
 68            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
 69            tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
 70            td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
 71            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
 72
 73            // Compute key-value product
 74            sycl::float4 kv4 = k4 * _v;
 75
 76            // Accumulate weighted sum
 77            y += sycl::dot(r4, tf4 * kv4 + s4);
 78
 79            // Update state
 80            s4 = s4 * td4 + kv4;
 81
 82            // Store updated state
 83            state[j] = s4.x();
 84            state[j+1] = s4.y();
 85            state[j+2] = s4.z();
 86            state[j+3] = s4.w();
 87        }
 88
 89        dst[t] = y;
 90    }
 91
 92    // Save final state
 93    #pragma unroll
 94    for (int i = 0; i < head_size; i++) {
 95        dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
 96    }
 97}
 98
 99template <int block_size>
100static void rwkv_wkv7_f32_kernel(
101    const int B, const int T, const int C, const int H,
102    const float* r, const float* w, const float* k, const float* v,
103    const float* a, const float* b, const float* s,
104    float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
105
106    const int tid = item_ct1.get_local_id(2);
107    const int bid = item_ct1.get_group(2);
108
109    const int head_size = block_size;
110    const int batch_i = bid / H;
111    const int head_i = bid % H;
112    const int state_size = C * head_size;
113    const int n_seq_tokens = T / B;
114
115    float* _r = shared_mem;
116    float* _w = _r + head_size;
117    float* _k = _w + head_size;
118    float* _a = _k + head_size;
119    float* _b = _a + head_size;
120
121    float state[block_size];
122
123    #pragma unroll
124    for (int i = 0; i < head_size; i++) {
125        state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
126    }
127
128    for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
129         t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
130         t += C) {
131
132        item_ct1.barrier(sycl::access::fence_space::local_space);
133
134        _r[tid] = r[t];
135        _w[tid] = w[t];
136        _k[tid] = k[t];
137        _a[tid] = a[t];
138        _b[tid] = b[t];
139
140        item_ct1.barrier(sycl::access::fence_space::local_space);
141
142        const float _v = v[t];
143        float y = 0, sa = 0;
144        sycl::float4 a4, s4;
145
146        #pragma unroll
147        for (int j = 0; j < head_size; j += 4) {
148            a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
149            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
150            sa += sycl::dot(a4, s4);
151        }
152
153        sycl::float4 r4, w4, k4, b4;
154        #pragma unroll
155        for (int j = 0; j < head_size; j += 4) {
156            r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
157            w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]);
158            k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
159            b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]);
160            s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
161
162            sycl::float4 kv4 = k4 * _v;
163
164            s4 = s4 * w4 + kv4 + sa * b4;
165            y += sycl::dot(r4, s4);
166
167            state[j] = s4.x();
168            state[j+1] = s4.y();
169            state[j+2] = s4.z();
170            state[j+3] = s4.w();
171        }
172
173        dst[t] = y;
174    }
175
176    #pragma unroll
177    for (int i = 0; i < head_size; i++) {
178        dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
179    }
180}
181
182void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
183    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/6);
184    const float* k_d = (const float*)dst->src[0]->data;
185    const float* v_d = (const float*)dst->src[1]->data;
186    const float* r_d = (const float*)dst->src[2]->data;
187    const float* tf_d = (const float*)dst->src[3]->data;
188    const float* td_d = (const float*)dst->src[4]->data;
189    const float* s_d = (const float*)dst->src[5]->data;
190    float* dst_d = (float*)dst->data;
191
192    const int64_t B = dst->src[5]->ne[1];
193    const int64_t T = dst->src[0]->ne[2];
194    const int64_t C = dst->ne[0];
195    const int64_t H = dst->src[0]->ne[1];
196
197    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
198    GGML_ASSERT(C % H == 0);
199    GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64
200
201    dpct::queue_ptr stream = ctx.stream();
202
203    // Calculate execution configuration
204    const size_t shared_mem_size = C / H * 4 * sizeof(float); // For k, r, tf, td
205    sycl::range<3> block_dims(1, 1, C / H);
206    sycl::range<3> grid_dims(1, 1, B * H);
207
208    // Submit kernel
209    if (C / H == WKV_BLOCK_SIZE) {
210        stream->submit([&](sycl::handler& cgh) {
211            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
212
213            cgh.parallel_for(
214                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
215                [=](sycl::nd_item<3> item_ct1) {
216                    rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>(
217                        B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
218                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
219                    );
220                });
221        });
222    } else {
223        stream->submit([&](sycl::handler& cgh) {
224            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
225
226            cgh.parallel_for(
227                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
228                [=](sycl::nd_item<3> item_ct1) {
229                    rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>(
230                        B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
231                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
232                    );
233                });
234        });
235    }
236}
237
238void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
239    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/7);
240    const float* r_d = (const float*)dst->src[0]->data;
241    const float* w_d = (const float*)dst->src[1]->data;
242    const float* k_d = (const float*)dst->src[2]->data;
243    const float* v_d = (const float*)dst->src[3]->data;
244    const float* a_d = (const float*)dst->src[4]->data;
245    const float* b_d = (const float*)dst->src[5]->data;
246    const float* s_d = (const float*)dst->src[6]->data;
247    float* dst_d = (float*)dst->data;
248
249    const int64_t B = dst->src[6]->ne[1];
250    const int64_t T = dst->src[0]->ne[2];
251    const int64_t C = dst->ne[0];
252    const int64_t H = dst->src[0]->ne[1];
253
254    GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
255    GGML_ASSERT(C % H == 0);
256    GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2);
257
258    dpct::queue_ptr stream = ctx.stream();
259
260    // Calculate execution configuration
261    const size_t shared_mem_size = C / H * 5 * sizeof(float); // For r, w, k, a, b
262    sycl::range<3> block_dims(1, 1, C / H);
263    sycl::range<3> grid_dims(1, 1, B * H);
264
265    // Submit kernel
266    if (C / H == WKV_BLOCK_SIZE) {
267        stream->submit([&](sycl::handler& cgh) {
268            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
269
270            cgh.parallel_for(
271                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
272                [=](sycl::nd_item<3> item_ct1) {
273                    rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>(
274                        B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
275                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
276                    );
277                });
278        });
279    } else {
280        stream->submit([&](sycl::handler& cgh) {
281            sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh);
282
283            cgh.parallel_for(
284                sycl::nd_range<3>(grid_dims * block_dims, block_dims),
285                [=](sycl::nd_item<3> item_ct1) {
286                    rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>(
287                        B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d,
288                        item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get()
289                    );
290                });
291        });
292    }
293}