#include "norm.hpp"
#include "ggml-sycl/common.hpp"
#include "ggml-sycl/presets.hpp"

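// LayerNorm kernel: one work-group per (sample, channel, row). Each thread
// accumulates a partial sum and sum-of-squares over the row; partials are
// reduced within the warp, then across warps through the s_sum scratch buffer
// when the work-group spans more than one warp, and the row is written back
// as (x - mean) * rsqrt(var + eps).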
static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
                     const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {

    const int nrows     = item_ct1.get_group_range(2);
    const int nchannels = item_ct1.get_group_range(1);

    const int nthreads = item_ct1.get_local_range(2);
    const int sample   = item_ct1.get_group(0);
    const int channel  = item_ct1.get_group(1);
    const int row      = item_ct1.get_group(2);

    const int tid    = item_ct1.get_local_id(2);
    const int nwarps = nthreads / WARP_SIZE;

    const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
    const auto packed_offset  = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});

    x   += strided_offset;
    dst += packed_offset;

    sycl::float2 mean_var = sycl::float2(0.f, 0.f);

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        mean_var.x() += xi;
        mean_var.y() += xi * xi;
    }

    // sum up partial sums
    mean_var = warp_reduce_sum(mean_var, item_ct1);
    if (block_size > WARP_SIZE) {
        const auto sub_group = item_ct1.get_sub_group();
        const auto sg_id     = sub_group.get_group_linear_id();
        const auto wi_in_sg  = sub_group.get_local_linear_id();
        if (wi_in_sg == 0) {
            s_sum[sg_id] = mean_var;
        }
        item_ct1.barrier(sycl::access::fence_space::local_space);
        mean_var = 0.f;
        const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
        for (size_t i = 0; i < nreduce; i += 1) {
            mean_var += s_sum[wi_in_sg + i * WARP_SIZE];
        }
        mean_var = warp_reduce_sum(mean_var, item_ct1);
    }

    const float mean    = mean_var.x() / ncols;
    const float var     = mean_var.y() / ncols - mean * mean;
    const float inv_std = sycl::rsqrt(var + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = (x[col] - mean) * inv_std;
    }
}

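// GroupNorm kernel: one work-group per group of group_size consecutive
// elements. Two passes: the first reduces the sum to get the mean and stores
// x - mean into dst, the second reduces the squared deviations and rescales
// dst by rsqrt(variance + eps).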
static void group_norm_f32(const float* x, float* dst, const int group_size, const int ne_elements, const float eps,
                           const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
    int start = item_ct1.get_group(2) * group_size;
    int end   = start + group_size;
    const int nthreads = item_ct1.get_local_range(2);
    const int nwarps   = nthreads / WARP_SIZE;
    start += item_ct1.get_local_id(2);
    const size_t nreduce = ceil_div(nwarps, WARP_SIZE);

    if (end >= ne_elements) {
        end = ne_elements;
    }

    float tmp = 0.0f; // partial sum for thread in warp

    for (int j = start; j < end; j += block_size) {
        tmp += x[j];
    }

    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        item_ct1.barrier(sycl::access::fence_space::local_space);
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1) {
            tmp += s_sum[lane_id + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }

    float mean = tmp / group_size;
    tmp = 0.0f;

    for (int j = start; j < end; j += block_size) {
        float xi = x[j] - mean;
        dst[j] = xi;
        tmp += xi * xi;
    }

    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        item_ct1.barrier(sycl::access::fence_space::local_space);
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1) {
            tmp += s_sum[lane_id + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }

    float variance = tmp / group_size;
    float scale    = sycl::rsqrt(variance + eps);
    for (int j = start; j < end; j += block_size) {
        dst[j] *= scale;
    }
}

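// RMSNorm kernel: same work decomposition as norm_f32, but only sum(x^2) is
// reduced; each element is scaled by rsqrt(mean(x^2) + eps) with no mean
// subtraction.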
static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
                         const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {

    const int nrows     = item_ct1.get_group_range(2);
    const int nchannels = item_ct1.get_group_range(1);

    const int sample  = item_ct1.get_group(0);
    const int channel = item_ct1.get_group(1);
    const int row     = item_ct1.get_group(2);

    const int nthreads = item_ct1.get_local_range(2);

    const int tid    = item_ct1.get_local_id(2);
    const int nwarps = nthreads / WARP_SIZE;

    const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
    const auto packed_offset  = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});

    x   += strided_offset;
    dst += packed_offset;

    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        const auto sub_group = item_ct1.get_sub_group();
        const auto sg_id     = sub_group.get_group_linear_id();
        const auto wi_in_sg  = sub_group.get_local_linear_id();
        if (wi_in_sg == 0) {
            s_sum[sg_id] = tmp;
        }

        item_ct1.barrier(sycl::access::fence_space::local_space);
        const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1) {
            tmp += s_sum[wi_in_sg + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }

    const float mean  = tmp / ncols;
    const float scale = sycl::rsqrt(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = scale * x[col];
    }
}

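// L2-norm kernel: normalizes each contiguous row by its Euclidean norm,
// scaling by rsqrt(max(sum(x^2), eps^2)) so that all-zero rows stay finite.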
static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
                        const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                    item_ct1.get_local_id(1);
    const int tid      = item_ct1.get_local_id(2);
    const int nthreads = item_ct1.get_local_range(2);
    const int nwarps   = nthreads / WARP_SIZE;
    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row * ncols + col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        item_ct1.barrier(sycl::access::fence_space::local_space);
        const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1) {
            tmp += s_sum[lane_id + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }

    const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));

    for (int col = tid; col < ncols; col += block_size) {
        dst[row * ncols + col] = scale * x[row * ncols + col];
    }
}

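// Launcher: rows with fewer than 1024 columns use a single warp per row and
// need no local memory (s_sum is nullptr); wider rows use the device's maximum
// work-group size with one local-memory slot per warp for the cross-warp
// reduction.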
static void norm_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
                          const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
                          const float eps, queue_ptr stream, int device) {

    const sycl::range<3> global_dims(nsamples, nchannels, nrows);
    if (ncols < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
                    });
        });
    } else {
        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
        const sycl::range<3> block_dims(1, 1, work_group_size);
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
                sycl::range<1>(work_group_size / WARP_SIZE), cgh);
            cgh.parallel_for(
                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
                    });
        });
    }
}

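// Launcher: one work-group per group; small groups run on a single warp,
// larger ones on a full work-group with per-warp local-memory slots.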
static void group_norm_f32_sycl(const float* x, float* dst,
                                const int num_groups, const float eps, const int group_size,
                                const int ne_elements, queue_ptr stream, int device) {
    if (group_size < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
                                  block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        group_norm_f32(
                            x, dst, group_size, ne_elements, eps, item_ct1,
                            nullptr, WARP_SIZE);
                    });
        });
    } else {
        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
        const sycl::range<3> block_dims(1, 1, work_group_size);
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                         cgh);
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
                                  block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        group_norm_f32(x, dst, group_size, ne_elements,
                                       eps, item_ct1,
                                       get_pointer(s_sum_acc_ct1), work_group_size);
                    });
        });
    }
}

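// Launcher: same dispatch strategy as norm_f32_sycl (single warp below 1024
// columns, otherwise the maximum work-group size with local memory).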
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
                              const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {

    const sycl::range<3> global_dims(nsamples, nchannels, nrows);
    if (ncols < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
                    });
        });
    } else {
        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
        const sycl::range<3> block_dims(1, 1, work_group_size);
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                         cgh);
            cgh.parallel_for(
                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
                    });
        });
    }
}

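// Launcher: one work-group per row, dispatched like the other norm launchers.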
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
                             const int nrows, const float eps,
                             queue_ptr stream, int device) {
    if (ncols < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                                  block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        l2_norm_f32(x, dst, ncols, eps, item_ct1,
                                    nullptr, WARP_SIZE);
                    });
        });
    } else {
        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
        const sycl::range<3> block_dims(1, 1, work_group_size);
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE),
                                                         cgh);
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                                  block_dims),
                [=](sycl::nd_item<3> item_ct1)
                    [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                        l2_norm_f32(x, dst, ncols, eps, item_ct1,
                                    get_pointer(s_sum_acc_ct1), work_group_size);
                    });
        });
    }
}

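// LayerNorm over the first dimension (ne00) of src0; byte strides are
// converted to element strides so the kernel can handle non-contiguous inputs.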
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    GGML_TENSOR_UNARY_OP_LOCALS
    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
    const float * src0_dd = static_cast<const float *>(src0->data);
    float * dst_dd = static_cast<float *>(dst->data);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);
    const size_t ts0 = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == ts0);
    const int64_t s01 = nb01 / ts0;
    const int64_t s02 = nb02 / ts0;
    const int64_t s03 = nb03 / ts0;

    norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
}

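// GroupNorm: op_params[0] holds the group count and op_params[1] the epsilon;
// the group size is derived by splitting ne[2] across the groups.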
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    int num_groups = dst->op_params[0];
    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const float * src0_dd = static_cast<const float *>(src0->data);
    float * dst_dd = static_cast<float *>(dst->data);

    float eps;
    memcpy(&eps, dst->op_params + 1, sizeof(float));

    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
}

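// RMSNorm over ne00, with epsilon taken from op_params; like the LayerNorm
// path, byte strides are converted to element strides.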
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const float * src0_dd = static_cast<const float *>(src0->data);
    float * dst_dd = static_cast<float *>(dst->data);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    GGML_TENSOR_UNARY_OP_LOCALS
    const size_t ts0 = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == ts0);
    const int64_t s01 = nb01 / ts0;
    const int64_t s02 = nb02 / ts0;
    const int64_t s03 = nb03 / ts0;
    rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
}

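// RMSNorm backward: given dz (src0) and x (src1), computes per row
//   dx = (dz - x * sum(x*dz) / (sum(x^2) + eps*D)) * rsqrt(mean(x^2) + eps)
// i.e. dx = (dz + coeff * x) * inv_r with coeff and inv_r as defined below.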
void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);

    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); // dz
    GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); // x
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    float eps = 1e-5f;
    std::memcpy(&eps, dst->op_params, sizeof(float));
    if (!(eps > 0.0f) || !std::isfinite(eps)) {
        eps = 1e-5f;
    }

    const float * g_base  = static_cast<const float *>(dst->src[0]->data); // dz
    const float * x_base  = static_cast<const float *>(dst->src[1]->data); // x
    float *       dx_base = static_cast<float *>(dst->data);

    const int64_t D = dst->ne[0];
    const int64_t n1 = dst->ne[1], n2 = dst->ne[2], n3 = dst->ne[3]; (void) n3;
    const int64_t N = ggml_nrows(dst);
    if (D == 0 || N == 0) {
        return;
    }

    const ggml_tensor *G = dst->src[0];
    const ggml_tensor *X = dst->src[1];
    const int ts = (int) ggml_type_size(X->type);
    GGML_ASSERT((size_t) X->nb[0]   == (size_t) ts);
    GGML_ASSERT((size_t) G->nb[0]   == (size_t) ts);
    GGML_ASSERT((size_t) dst->nb[0] == (size_t) ts);

    const int64_t xs1 = X->nb[1] / ts,   xs2 = X->nb[2] / ts,   xs3 = X->nb[3] / ts;
    const int64_t gs1 = G->nb[1] / ts,   gs2 = G->nb[2] / ts,   gs3 = G->nb[3] / ts;
    const int64_t ds1 = dst->nb[1] / ts, ds2 = dst->nb[2] / ts, ds3 = dst->nb[3] / ts;

    dpct::queue_ptr q = ctx.stream();

    // work-group size: multiple of WARP_SIZE, capped by the device limit and 256, and not larger than D
    const int device_max_wg = ggml_sycl_info().max_work_group_sizes[ctx.device];
    auto roundup = [](int v, int m) { return ((v + m - 1) / m) * m; };
    int wg_cap = 256;
    if (device_max_wg > 0) {
        wg_cap = std::min(wg_cap, device_max_wg);
    }
    const int WG = std::max(WARP_SIZE, std::min(roundup((int) std::min<int64_t>(D, wg_cap), WARP_SIZE), wg_cap));

    // FP32 path: per-thread compensated accumulation + hierarchical reduction
    q->submit([&](sycl::handler &cgh) {
        const int nwarps_loc = std::max(1, WG / WARP_SIZE);
        // store one partial value per warp (xx and xg) for cross-warp reduction
        auto l_xx = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
        auto l_xg = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);

        cgh.parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, N) * sycl::range<3>(1, 1, WG),
                              sycl::range<3>(1, 1, WG)),
            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                const int row = item_ct1.get_group(2);
                const int tid = item_ct1.get_local_id(2);

                const int64_t i1 = row % n1;
                const int64_t i2 = (row / n1) % n2;
                const int64_t i3 = row / (n1 * n2);

                const float *__restrict x_row = x_base  + i3 * xs3 + i2 * xs2 + i1 * xs1;
                const float *__restrict g_row = g_base  + i3 * gs3 + i2 * gs2 + i1 * gs1;
                float       *__restrict d_row = dx_base + i3 * ds3 + i2 * ds2 + i1 * ds1;

                // per-thread accumulation (Kahan-compensated unless GGML_SYCL_RMS_BACK_FAST is defined)
                float sum_xx = 0.f, sum_xg = 0.f;
#ifndef GGML_SYCL_RMS_BACK_FAST
                float c_xx = 0.f, c_xg = 0.f;
#endif
                for (int64_t col = tid; col < D; col += WG) {
                    const float xv = x_row[col];
                    const float gv = g_row[col];
#ifdef GGML_SYCL_RMS_BACK_FAST
                    sum_xx += xv * xv;
                    sum_xg += xv * gv;
#else
                    float y1 = xv * xv - c_xx;
                    float t1 = sum_xx + y1;
                    c_xx   = (t1 - sum_xx) - y1;
                    sum_xx = t1;

                    float y2 = xv * gv - c_xg;
                    float t2 = sum_xg + y2;
                    c_xg   = (t2 - sum_xg) - y2;
                    sum_xg = t2;
#endif
                }

                // warp-level reduction (value in .x(), compensation term in .y())
                sycl::float2 xx = sycl::float2(sum_xx,
#ifndef GGML_SYCL_RMS_BACK_FAST
                                               c_xx
#else
                                               0.f
#endif
                );
                sycl::float2 xg = sycl::float2(sum_xg,
#ifndef GGML_SYCL_RMS_BACK_FAST
                                               c_xg
#else
                                               0.f
#endif
                );
                xx = warp_reduce_sum(xx, item_ct1);
                xg = warp_reduce_sum(xg, item_ct1);

                // cross-warp reduction using local memory (single barrier)
                const auto sub_group = item_ct1.get_sub_group();
                const auto sg_id     = sub_group.get_group_linear_id();
                const auto wi_in_sg  = sub_group.get_local_linear_id();
                const int  nthreads  = item_ct1.get_local_range(2);
                const int  nwarps    = nthreads / WARP_SIZE;

                sycl::float2 xx_total = xx;
                sycl::float2 xg_total = xg;
                if (nwarps > 1) {
                    if (wi_in_sg == 0) {
                        l_xx[sg_id] = xx;
                        l_xg[sg_id] = xg;
                    }
                    item_ct1.barrier(sycl::access::fence_space::local_space);

                    if (sg_id == 0) {
                        const unsigned wi_u = wi_in_sg;
                        sycl::float2 xx_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xx[wi_u] : sycl::float2(0.f, 0.f);
                        sycl::float2 xg_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xg[wi_u] : sycl::float2(0.f, 0.f);
                        xx_total = warp_reduce_sum(xx_first, item_ct1);
                        xg_total = warp_reduce_sum(xg_first, item_ct1);
                    } else {
                        // other subgroups keep their local totals; they'll be ignored
                        xx_total = xx;
                        xg_total = xg;
                    }
                    // all threads see the first-subgroup result via the broadcast below
                }

                // compute inv_r and coeff once per row and broadcast to the whole work-group
                float inv_r = 0.f;
                float coeff = 0.f;
                if (tid == 0) {
                    const float sum_xx_f  = xx_total.x() + xx_total.y();
                    const float sum_xdz_f = xg_total.x() + xg_total.y();
                    const float mean_eps  = sum_xx_f / (float) D + eps;
                    const float sum_eps   = sum_xx_f + eps * (float) D;
                    inv_r = sycl::rsqrt(mean_eps);
                    coeff = -sum_xdz_f / sum_eps;
                }
                inv_r = sycl::group_broadcast(item_ct1.get_group(), inv_r);
                coeff = sycl::group_broadcast(item_ct1.get_group(), coeff);

                for (int64_t col = tid; col < D; col += WG) {
                    d_row[col] = (g_row[col] + coeff * x_row[col]) * inv_r;
                }
            });
    });
}

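// L2 normalization over ne00 of src0, one row per work-group.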
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const int64_t ne00  = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);
    const float * src0_dd = static_cast<const float *>(src0->data);
    float * dst_dd = static_cast<float *>(dst->data);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
}