diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-sycl/wkv.cpp | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-sycl/wkv.cpp')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-sycl/wkv.cpp | 293 |
1 files changed, 293 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-sycl/wkv.cpp b/llama.cpp/ggml/src/ggml-sycl/wkv.cpp new file mode 100644 index 0000000..b56e0c2 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-sycl/wkv.cpp | |||
| @@ -0,0 +1,293 @@ | |||
| 1 | #include <sycl/sycl.hpp> | ||
| 2 | #include "wkv.hpp" | ||
| 3 | |||
| 4 | constexpr int WKV_BLOCK_SIZE = 64; | ||
| 5 | |||
| 6 | // Helper function for the main kernel | ||
| 7 | template <int block_size> | ||
| 8 | static void rwkv_wkv6_f32_kernel( | ||
| 9 | const int B, const int T, const int C, const int H, | ||
| 10 | const float* k, const float* v, const float* r, | ||
| 11 | const float* tf, const float* td, const float* s, | ||
| 12 | float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) { | ||
| 13 | |||
| 14 | const int tid = item_ct1.get_local_id(2); | ||
| 15 | const int bid = item_ct1.get_group(2); | ||
| 16 | |||
| 17 | const int head_size = block_size; | ||
| 18 | const int batch_i = bid / H; | ||
| 19 | const int head_i = bid % H; | ||
| 20 | const int state_size = C * head_size; | ||
| 21 | const int n_seq_tokens = T / B; | ||
| 22 | |||
| 23 | // Set up shared memory pointers | ||
| 24 | float* _k = shared_mem; | ||
| 25 | float* _r = _k + head_size; | ||
| 26 | float* _tf = _r + head_size; | ||
| 27 | float* _td = _tf + head_size; | ||
| 28 | |||
| 29 | // Local state array | ||
| 30 | float state[block_size]; | ||
| 31 | |||
| 32 | // Load initial state | ||
| 33 | #pragma unroll | ||
| 34 | for (int i = 0; i < head_size; i++) { | ||
| 35 | state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid]; | ||
| 36 | } | ||
| 37 | |||
| 38 | // Sync threads before shared memory operations | ||
| 39 | item_ct1.barrier(sycl::access::fence_space::local_space); | ||
| 40 | |||
| 41 | // Load time-mixing parameters | ||
| 42 | _tf[tid] = tf[head_i * head_size + tid]; | ||
| 43 | item_ct1.barrier(sycl::access::fence_space::local_space); | ||
| 44 | |||
| 45 | // Main sequence processing loop | ||
| 46 | for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; | ||
| 47 | t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; | ||
| 48 | t += C) { | ||
| 49 | |||
| 50 | item_ct1.barrier(sycl::access::fence_space::local_space); | ||
| 51 | |||
| 52 | // Load current timestep data to shared memory | ||
| 53 | _k[tid] = k[t]; | ||
| 54 | _r[tid] = r[t]; | ||
| 55 | _td[tid] = td[t]; | ||
| 56 | |||
| 57 | item_ct1.barrier(sycl::access::fence_space::local_space); | ||
| 58 | |||
| 59 | const float _v = v[t]; | ||
| 60 | float y = 0; | ||
| 61 | |||
| 62 | // Process in chunks of 4 for better vectorization | ||
| 63 | sycl::float4 k4, r4, tf4, td4, s4; | ||
| 64 | #pragma unroll | ||
| 65 | for (int j = 0; j < head_size; j += 4) { | ||
| 66 | // Load data in vec4 chunks | ||
| 67 | k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); | ||
| 68 | r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); | ||
| 69 | tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); | ||
| 70 | td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]); | ||
| 71 | s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]); | ||
| 72 | |||
| 73 | // Compute key-value product | ||
| 74 | sycl::float4 kv4 = k4 * _v; | ||
| 75 | |||
| 76 | // Accumulate weighted sum | ||
| 77 | y += sycl::dot(r4, tf4 * kv4 + s4); | ||
| 78 | |||
| 79 | // Update state | ||
| 80 | s4 = s4 * td4 + kv4; | ||
| 81 | |||
| 82 | // Store updated state | ||
| 83 | state[j] = s4.x(); | ||
| 84 | state[j+1] = s4.y(); | ||
| 85 | state[j+2] = s4.z(); | ||
| 86 | state[j+3] = s4.w(); | ||
| 87 | } | ||
| 88 | |||
| 89 | dst[t] = y; | ||
| 90 | } | ||
| 91 | |||
| 92 | // Save final state | ||
| 93 | #pragma unroll | ||
| 94 | for (int i = 0; i < head_size; i++) { | ||
| 95 | dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i]; | ||
| 96 | } | ||
| 97 | } | ||
| 98 | |||
| 99 | template <int block_size> | ||
| 100 | static void rwkv_wkv7_f32_kernel( | ||
| 101 | const int B, const int T, const int C, const int H, | ||
| 102 | const float* r, const float* w, const float* k, const float* v, | ||
| 103 | const float* a, const float* b, const float* s, | ||
| 104 | float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) { | ||
| 105 | |||
| 106 | const int tid = item_ct1.get_local_id(2); | ||
| 107 | const int bid = item_ct1.get_group(2); | ||
| 108 | |||
| 109 | const int head_size = block_size; | ||
| 110 | const int batch_i = bid / H; | ||
| 111 | const int head_i = bid % H; | ||
| 112 | const int state_size = C * head_size; | ||
| 113 | const int n_seq_tokens = T / B; | ||
| 114 | |||
| 115 | float* _r = shared_mem; | ||
| 116 | float* _w = _r + head_size; | ||
| 117 | float* _k = _w + head_size; | ||
| 118 | float* _a = _k + head_size; | ||
| 119 | float* _b = _a + head_size; | ||
| 120 | |||
| 121 | float state[block_size]; | ||
| 122 | |||
| 123 | #pragma unroll | ||
| 124 | for (int i = 0; i < head_size; i++) { | ||
| 125 | state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i]; | ||
| 126 | } | ||
| 127 | |||
| 128 | for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; | ||
| 129 | t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; | ||
| 130 | t += C) { | ||
| 131 | |||
| 132 | item_ct1.barrier(sycl::access::fence_space::local_space); | ||
| 133 | |||
| 134 | _r[tid] = r[t]; | ||
| 135 | _w[tid] = w[t]; | ||
| 136 | _k[tid] = k[t]; | ||
| 137 | _a[tid] = a[t]; | ||
| 138 | _b[tid] = b[t]; | ||
| 139 | |||
| 140 | item_ct1.barrier(sycl::access::fence_space::local_space); | ||
| 141 | |||
| 142 | const float _v = v[t]; | ||
| 143 | float y = 0, sa = 0; | ||
| 144 | sycl::float4 a4, s4; | ||
| 145 | |||
| 146 | #pragma unroll | ||
| 147 | for (int j = 0; j < head_size; j += 4) { | ||
| 148 | a4 = sycl::float4(_a[j], _a[j+1], _a[j+2], _a[j+3]); | ||
| 149 | s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]); | ||
| 150 | sa += sycl::dot(a4, s4); | ||
| 151 | } | ||
| 152 | |||
| 153 | sycl::float4 r4, w4, k4, b4; | ||
| 154 | #pragma unroll | ||
| 155 | for (int j = 0; j < head_size; j += 4) { | ||
| 156 | r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]); | ||
| 157 | w4 = sycl::float4(_w[j], _w[j+1], _w[j+2], _w[j+3]); | ||
| 158 | k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); | ||
| 159 | b4 = sycl::float4(_b[j], _b[j+1], _b[j+2], _b[j+3]); | ||
| 160 | s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]); | ||
| 161 | |||
| 162 | sycl::float4 kv4 = k4 * _v; | ||
| 163 | |||
| 164 | s4 = s4 * w4 + kv4 + sa * b4; | ||
| 165 | y += sycl::dot(r4, s4); | ||
| 166 | |||
| 167 | state[j] = s4.x(); | ||
| 168 | state[j+1] = s4.y(); | ||
| 169 | state[j+2] = s4.z(); | ||
| 170 | state[j+3] = s4.w(); | ||
| 171 | } | ||
| 172 | |||
| 173 | dst[t] = y; | ||
| 174 | } | ||
| 175 | |||
| 176 | #pragma unroll | ||
| 177 | for (int i = 0; i < head_size; i++) { | ||
| 178 | dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i]; | ||
| 179 | } | ||
| 180 | } | ||
| 181 | |||
| 182 | void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { | ||
| 183 | scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/6); | ||
| 184 | const float* k_d = (const float*)dst->src[0]->data; | ||
| 185 | const float* v_d = (const float*)dst->src[1]->data; | ||
| 186 | const float* r_d = (const float*)dst->src[2]->data; | ||
| 187 | const float* tf_d = (const float*)dst->src[3]->data; | ||
| 188 | const float* td_d = (const float*)dst->src[4]->data; | ||
| 189 | const float* s_d = (const float*)dst->src[5]->data; | ||
| 190 | float* dst_d = (float*)dst->data; | ||
| 191 | |||
| 192 | const int64_t B = dst->src[5]->ne[1]; | ||
| 193 | const int64_t T = dst->src[0]->ne[2]; | ||
| 194 | const int64_t C = dst->ne[0]; | ||
| 195 | const int64_t H = dst->src[0]->ne[1]; | ||
| 196 | |||
| 197 | GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32); | ||
| 198 | GGML_ASSERT(C % H == 0); | ||
| 199 | GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64 | ||
| 200 | |||
| 201 | dpct::queue_ptr stream = ctx.stream(); | ||
| 202 | |||
| 203 | // Calculate execution configuration | ||
| 204 | const size_t shared_mem_size = C / H * 4 * sizeof(float); // For k, r, tf, td | ||
| 205 | sycl::range<3> block_dims(1, 1, C / H); | ||
| 206 | sycl::range<3> grid_dims(1, 1, B * H); | ||
| 207 | |||
| 208 | // Submit kernel | ||
| 209 | if (C / H == WKV_BLOCK_SIZE) { | ||
| 210 | stream->submit([&](sycl::handler& cgh) { | ||
| 211 | sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); | ||
| 212 | |||
| 213 | cgh.parallel_for( | ||
| 214 | sycl::nd_range<3>(grid_dims * block_dims, block_dims), | ||
| 215 | [=](sycl::nd_item<3> item_ct1) { | ||
| 216 | rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE>( | ||
| 217 | B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d, | ||
| 218 | item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() | ||
| 219 | ); | ||
| 220 | }); | ||
| 221 | }); | ||
| 222 | } else { | ||
| 223 | stream->submit([&](sycl::handler& cgh) { | ||
| 224 | sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); | ||
| 225 | |||
| 226 | cgh.parallel_for( | ||
| 227 | sycl::nd_range<3>(grid_dims * block_dims, block_dims), | ||
| 228 | [=](sycl::nd_item<3> item_ct1) { | ||
| 229 | rwkv_wkv6_f32_kernel<WKV_BLOCK_SIZE * 2>( | ||
| 230 | B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d, | ||
| 231 | item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() | ||
| 232 | ); | ||
| 233 | }); | ||
| 234 | }); | ||
| 235 | } | ||
| 236 | } | ||
| 237 | |||
| 238 | void ggml_sycl_op_rwkv_wkv7(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { | ||
| 239 | scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/7); | ||
| 240 | const float* r_d = (const float*)dst->src[0]->data; | ||
| 241 | const float* w_d = (const float*)dst->src[1]->data; | ||
| 242 | const float* k_d = (const float*)dst->src[2]->data; | ||
| 243 | const float* v_d = (const float*)dst->src[3]->data; | ||
| 244 | const float* a_d = (const float*)dst->src[4]->data; | ||
| 245 | const float* b_d = (const float*)dst->src[5]->data; | ||
| 246 | const float* s_d = (const float*)dst->src[6]->data; | ||
| 247 | float* dst_d = (float*)dst->data; | ||
| 248 | |||
| 249 | const int64_t B = dst->src[6]->ne[1]; | ||
| 250 | const int64_t T = dst->src[0]->ne[2]; | ||
| 251 | const int64_t C = dst->ne[0]; | ||
| 252 | const int64_t H = dst->src[0]->ne[1]; | ||
| 253 | |||
| 254 | GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32); | ||
| 255 | GGML_ASSERT(C % H == 0); | ||
| 256 | GGML_ASSERT(C / H == WKV_BLOCK_SIZE || C / H == WKV_BLOCK_SIZE * 2); | ||
| 257 | |||
| 258 | dpct::queue_ptr stream = ctx.stream(); | ||
| 259 | |||
| 260 | // Calculate execution configuration | ||
| 261 | const size_t shared_mem_size = C / H * 5 * sizeof(float); // For r, w, k, a, b | ||
| 262 | sycl::range<3> block_dims(1, 1, C / H); | ||
| 263 | sycl::range<3> grid_dims(1, 1, B * H); | ||
| 264 | |||
| 265 | // Submit kernel | ||
| 266 | if (C / H == WKV_BLOCK_SIZE) { | ||
| 267 | stream->submit([&](sycl::handler& cgh) { | ||
| 268 | sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); | ||
| 269 | |||
| 270 | cgh.parallel_for( | ||
| 271 | sycl::nd_range<3>(grid_dims * block_dims, block_dims), | ||
| 272 | [=](sycl::nd_item<3> item_ct1) { | ||
| 273 | rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE>( | ||
| 274 | B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, | ||
| 275 | item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() | ||
| 276 | ); | ||
| 277 | }); | ||
| 278 | }); | ||
| 279 | } else { | ||
| 280 | stream->submit([&](sycl::handler& cgh) { | ||
| 281 | sycl::local_accessor<float, 1> shared_mem_acc(shared_mem_size, cgh); | ||
| 282 | |||
| 283 | cgh.parallel_for( | ||
| 284 | sycl::nd_range<3>(grid_dims * block_dims, block_dims), | ||
| 285 | [=](sycl::nd_item<3> item_ct1) { | ||
| 286 | rwkv_wkv7_f32_kernel<WKV_BLOCK_SIZE * 2>( | ||
| 287 | B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d, | ||
| 288 | item_ct1, (float*)shared_mem_acc.get_multi_ptr<sycl::access::decorated::no>().get() | ||
| 289 | ); | ||
| 290 | }); | ||
| 291 | }); | ||
| 292 | } | ||
| 293 | } | ||
