#include "norm.hpp"
#include "ggml-sycl/common.hpp"
#include "ggml-sycl/presets.hpp"

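// The kernels in this file share a two-level reduction pattern: each
// work-item accumulates a private partial sum over a strided slice of its
// row, the partials are combined within each sub-group by warp_reduce_sum(),
// and, when the work-group holds more than one sub-group, the per-sub-group
// results are staged in local memory (s_sum) and reduced by one more
// warp_reduce_sum() pass. s_sum may be nullptr when block_size == WARP_SIZE,
// since the cross-warp step is skipped in that case.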
static void norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, sycl::float2* s_sum, int block_size) {

    const int nrows = item_ct1.get_group_range(2);
    const int nchannels = item_ct1.get_group_range(1);

    const int nthreads = item_ct1.get_local_range(2);
    const int sample  = item_ct1.get_group(0);
    const int channel = item_ct1.get_group(1);
    const int row     = item_ct1.get_group(2);

    const int tid = item_ct1.get_local_id(2);
    const int nwarps = nthreads / WARP_SIZE;

    const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
    const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});

    x += strided_offset;
    dst += packed_offset;

    sycl::float2 mean_var = sycl::float2(0.f, 0.f);

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        mean_var.x() += xi;
        mean_var.y() += xi * xi;
    }

    // sum up partial sums
    mean_var = warp_reduce_sum(mean_var, item_ct1);
    if (block_size > WARP_SIZE) {
        const auto sub_group = item_ct1.get_sub_group();
        const auto sg_id = sub_group.get_group_linear_id();
        const auto wi_in_sg = sub_group.get_local_linear_id();
        if (wi_in_sg == 0) {
            s_sum[sg_id] = mean_var;
        }
        item_ct1.barrier(sycl::access::fence_space::local_space);
        mean_var = sycl::float2(0.f, 0.f);
        const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
        for (size_t i = 0; i < nreduce; i += 1) {
            mean_var += s_sum[wi_in_sg + i * WARP_SIZE];
        }
        mean_var = warp_reduce_sum(mean_var, item_ct1);
    }

    const float mean = mean_var.x() / ncols;
    const float var = mean_var.y() / ncols - mean * mean;
    const float inv_std = sycl::rsqrt(var + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = (x[col] - mean) * inv_std;
    }
}

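// group_norm: each work-group normalizes one contiguous group of group_size
// elements with that group's own mean and variance. Two reduction rounds are
// needed because the variance is computed from values centered on the mean
// found in the first round.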
static void group_norm_f32(const float* x, float* dst, const int group_size, const int ne_elements, const float eps,
    const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
    int start = item_ct1.get_group(2) * group_size;
    int end = start + group_size;
    const int nthreads = item_ct1.get_local_range(2);
    const int nwarps = nthreads / WARP_SIZE;
    start += item_ct1.get_local_id(2);
    const size_t nreduce = ceil_div(nwarps, WARP_SIZE);

    if (end >= ne_elements) {
        end = ne_elements;
    }

    float tmp = 0.0f; // partial sum for thread in warp

    for (int j = start; j < end; j += block_size) {
        tmp += x[j];
    }

    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        item_ct1.barrier(sycl::access::fence_space::local_space);
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1) {
            tmp += s_sum[lane_id + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }

    const float mean = tmp / group_size;
    tmp = 0.0f;

    for (int j = start; j < end; j += block_size) {
        const float xi = x[j] - mean;
        dst[j] = xi;
        tmp += xi * xi;
    }

    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        item_ct1.barrier(sycl::access::fence_space::local_space);
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1) {
            tmp += s_sum[lane_id + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }

    const float variance = tmp / group_size;
    const float scale = sycl::rsqrt(variance + eps);
    for (int j = start; j < end; j += block_size) {
        dst[j] *= scale;
    }
}

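// rms_norm: y = x / sqrt(mean(x^2) + eps). Unlike norm_f32 no mean is
// subtracted, so a single reduction over sum(x^2) per row suffices.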
static void rms_norm_f32(const float* x, float* dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps, const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {

    const int nrows = item_ct1.get_group_range(2);
    const int nchannels = item_ct1.get_group_range(1);

    const int sample  = item_ct1.get_group(0);
    const int channel = item_ct1.get_group(1);
    const int row     = item_ct1.get_group(2);

    const int nthreads = item_ct1.get_local_range(2);

    const int tid = item_ct1.get_local_id(2);
    const int nwarps = nthreads / WARP_SIZE;

    const auto strided_offset = calculate_offset<3>({stride_sample, stride_channel, stride_row}, {sample, channel, row});
    const auto packed_offset = calculate_offset<3>({nchannels * nrows * ncols, nrows * ncols, ncols}, {sample, channel, row});

    x   += strided_offset;
    dst += packed_offset;

    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        const auto sub_group = item_ct1.get_sub_group();
        const auto sg_id = sub_group.get_group_linear_id();
        const auto wi_in_sg = sub_group.get_local_linear_id();
        if (wi_in_sg == 0) {
            s_sum[sg_id] = tmp;
        }

        item_ct1.barrier(sycl::access::fence_space::local_space);
        const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1) {
            tmp += s_sum[wi_in_sg + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }

    const float mean = tmp / ncols;
    const float scale = sycl::rsqrt(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = scale * x[col];
    }
}

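// l2_norm: y = x / max(||x||_2, eps), implemented as one sum(x^2) reduction
// followed by rsqrt(max(sum, eps^2)) so a zero row cannot divide by zero.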
static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps,
    const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) {
    const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
        item_ct1.get_local_id(1);
    const int tid = item_ct1.get_local_id(2);
    const int nthreads = item_ct1.get_local_range(2);
    const int nwarps = nthreads / WARP_SIZE;
    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[row * ncols + col];
        tmp += xi * xi;
    }

    // sum up partial sums
    tmp = warp_reduce_sum(tmp, item_ct1);
    if (block_size > WARP_SIZE) {
        const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
        const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
        if (lane_id == 0) {
            s_sum[warp_id] = tmp;
        }
        item_ct1.barrier(sycl::access::fence_space::local_space);
        const size_t nreduce = ceil_div(nwarps, WARP_SIZE);
        tmp = 0.f;
        for (size_t i = 0; i < nreduce; i += 1) {
            tmp += s_sum[lane_id + i * WARP_SIZE];
        }
        tmp = warp_reduce_sum(tmp, item_ct1);
    }

    const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps));

    for (int col = tid; col < ncols; col += block_size) {
        dst[row * ncols + col] = scale * x[row * ncols + col];
    }
}

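// Launchers: one work-group per row over a (sample, channel, row) grid.
// Narrow rows (ncols < 1024) use a single warp so the cross-warp step and
// its local memory are skipped; wider rows use the device's maximum
// work-group size with one local-memory slot per sub-group.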
static void norm_f32_sycl(const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
        const float eps, queue_ptr stream, int device) {

    const sycl::range<3> global_dims(nsamples, nchannels, nrows);
    if (ncols < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                    norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
                });
        });
    }
    else {
        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
        const sycl::range<3> block_dims(1, 1, work_group_size);
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<sycl::float2, 1> s_sum_acc_ct1(
                sycl::range<1>(work_group_size / WARP_SIZE), cgh);
            cgh.parallel_for(
                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                    norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
                });
        });
    }
}

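// Same dispatch scheme as norm_f32_sycl, but with one work-group per group
// rather than per row, keyed on group_size instead of ncols.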
static void group_norm_f32_sycl(const float* x, float* dst,
    const int num_groups, const float eps, const int group_size,
    const int ne_elements, queue_ptr stream, int device) {
    if (group_size < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
                    block_dims),
                [=](sycl::nd_item<3> item_ct1)
                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                    group_norm_f32(
                        x, dst, group_size, ne_elements, eps, item_ct1,
                        nullptr, WARP_SIZE);
                });
        });
    }
    else {
        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
        const sycl::range<3> block_dims(1, 1, work_group_size);
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> s_sum_acc_ct1(
                sycl::range<1>(work_group_size / WARP_SIZE), cgh);
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, num_groups) * block_dims,
                    block_dims),
                [=](sycl::nd_item<3> item_ct1)
                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                    group_norm_f32(x, dst, group_size, ne_elements,
                        eps, item_ct1,
                        get_pointer(s_sum_acc_ct1), work_group_size);
                });
        });
    }
}

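// Identical dispatch scheme to norm_f32_sycl.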
static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) {

    const sycl::range<3> global_dims(nsamples, nchannels, nrows);
    if (ncols < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                    rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, nullptr, WARP_SIZE);
                });
        });
    }
    else {
        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
        const sycl::range<3> block_dims(1, 1, work_group_size);
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> s_sum_acc_ct1(
                sycl::range<1>(work_group_size / WARP_SIZE), cgh);
            cgh.parallel_for(
                sycl::nd_range<3>(global_dims * block_dims, block_dims),
                [=](sycl::nd_item<3> item_ct1)
                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                    rms_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size);
                });
        });
    }
}

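// Same scheme again, with a flat one-work-group-per-row grid.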
static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols,
    const int nrows, const float eps,
    queue_ptr stream, int device) {
    if (ncols < 1024) {
        const sycl::range<3> block_dims(1, 1, WARP_SIZE);
        stream->submit([&](sycl::handler& cgh) {
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                    block_dims),
                [=](sycl::nd_item<3> item_ct1)
                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                    l2_norm_f32(x, dst, ncols, eps, item_ct1,
                        nullptr, WARP_SIZE);
                });
        });
    }
    else {
        const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
        assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
        const sycl::range<3> block_dims(1, 1, work_group_size);
        stream->submit([&](sycl::handler& cgh) {
            sycl::local_accessor<float, 1> s_sum_acc_ct1(
                sycl::range<1>(work_group_size / WARP_SIZE), cgh);
            cgh.parallel_for(
                sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims,
                    block_dims),
                [=](sycl::nd_item<3> item_ct1)
                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                    l2_norm_f32(x, dst, ncols, eps, item_ct1,
                        get_pointer(s_sum_acc_ct1), work_group_size);
                });
        });
    }
}

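// Host-side op wrappers: they read eps from op_params, convert ggml byte
// strides (nb) into element strides for src0, and assume rows are contiguous
// (nb00 == type size, enforced by the asserts below).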
void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    GGML_TENSOR_UNARY_OP_LOCALS
    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
    const float * src0_dd = static_cast<const float *>(src0->data);
    float *       dst_dd  = static_cast<float *>(dst->data);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);
    const size_t ts0 = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == ts0);
    const int64_t s01 = nb01 / ts0;
    const int64_t s02 = nb02 / ts0;
    const int64_t s03 = nb03 / ts0;

    norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
}

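// group_norm packs its parameters differently: num_groups lives in
// op_params[0] and eps directly after it in op_params[1].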
void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int num_groups = dst->op_params[0];
    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const float * src0_dd = static_cast<const float *>(src0->data);
    float *       dst_dd  = static_cast<float *>(dst->data);

    float eps;
    memcpy(&eps, dst->op_params + 1, sizeof(float));

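    // each group spans ne[0]*ne[1] elements of ceil(ne[2]/num_groups) channels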
    const int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
    group_norm_f32_sycl(src0_dd, dst_dd, num_groups, eps, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device);
}

void ggml_sycl_op_rms_norm(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const float * src0_dd = static_cast<const float *>(src0->data);
    float *       dst_dd  = static_cast<float *>(dst->data);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    GGML_TENSOR_UNARY_OP_LOCALS
    const size_t ts0 = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == ts0);
    const int64_t s01 = nb01 / ts0;
    const int64_t s02 = nb02 / ts0;
    const int64_t s03 = nb03 / ts0;
    rms_norm_f32_sycl(src0_dd, dst_dd, ne00, ne01, ne02, ne03, s01, s02, s03, eps, main_stream, ctx.device);
}

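// rms_norm backward: for y = x * r with r = 1/sqrt(mean(x^2) + eps) and
// upstream gradient dz, the input gradient is
//   dx = (dz + coeff * x) * r,  coeff = -sum(x*dz) / (sum(x^2) + eps*D),
// which is what the kernel below evaluates row by row.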
void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);

    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); // dz
    GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); // x
    GGML_ASSERT(dst->type         == GGML_TYPE_F32);

    float eps = 1e-5f;
    std::memcpy(&eps, dst->op_params, sizeof(float));
    if (!(eps > 0.0f) || !std::isfinite(eps)) eps = 1e-5f;

    const float * g_base  = static_cast<const float *>(dst->src[0]->data); // dz
    const float * x_base  = static_cast<const float *>(dst->src[1]->data); // x
          float * dx_base = static_cast<      float *>(dst->data);

    const int64_t D  = dst->ne[0];
    const int64_t n1 = dst->ne[1], n2 = dst->ne[2], n3 = dst->ne[3]; (void) n3;
    const int64_t N  = ggml_nrows(dst);
    if (D == 0 || N == 0) return;

    const ggml_tensor *G = dst->src[0];
    const ggml_tensor *X = dst->src[1];
    const int ts = (int) ggml_type_size(X->type);
    GGML_ASSERT((size_t) X->nb[0]   == (size_t) ts);
    GGML_ASSERT((size_t) G->nb[0]   == (size_t) ts);
    GGML_ASSERT((size_t) dst->nb[0] == (size_t) ts);

    const int64_t xs1 = X->nb[1] / ts, xs2 = X->nb[2] / ts, xs3 = X->nb[3] / ts;
    const int64_t gs1 = G->nb[1] / ts, gs2 = G->nb[2] / ts, gs3 = G->nb[3] / ts;
    const int64_t ds1 = dst->nb[1] / ts, ds2 = dst->nb[2] / ts, ds3 = dst->nb[3] / ts;

    dpct::queue_ptr q = ctx.stream();

    // work-group size: multiple of WARP_SIZE, capped by the device limit and 256, and not larger than D
    const int device_max_wg = ggml_sycl_info().max_work_group_sizes[ctx.device];
    auto roundup = [](int v, int m) { return ((v + m - 1) / m) * m; };
    int wg_cap = 256;
    if (device_max_wg > 0) wg_cap = std::min(wg_cap, device_max_wg);
    const int WG = std::max(WARP_SIZE, std::min(roundup((int)std::min<int64_t>(D, wg_cap), WARP_SIZE), wg_cap));

    // FP32 path: per-thread compensated accumulation + hierarchical reduction
    q->submit([&](sycl::handler &cgh) {
        const int nwarps_loc = std::max(1, WG / WARP_SIZE);
        // store one partial value per warp (xx and xg) for the cross-warp reduction
        auto l_xx = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);
        auto l_xg = sycl::local_accessor<sycl::float2, 1>(sycl::range<1>(nwarps_loc), cgh);

        cgh.parallel_for(
            sycl::nd_range<3>(sycl::range<3>(1, 1, N) * sycl::range<3>(1, 1, WG),
                              sycl::range<3>(1, 1, WG)),
            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
                const int row = item_ct1.get_group(2);
                const int tid = item_ct1.get_local_id(2);

                const int64_t i1 = row % n1;
                const int64_t i2 = (row / n1) % n2;
                const int64_t i3 = row / (n1 * n2);

                const float *__restrict x_row = x_base + i3 * xs3 + i2 * xs2 + i1 * xs1;
                const float *__restrict g_row = g_base + i3 * gs3 + i2 * gs2 + i1 * gs1;
                float *__restrict d_row       = dx_base + i3 * ds3 + i2 * ds2 + i1 * ds1;

                // per-thread accumulation (compensated by default)
                float sum_xx = 0.f, sum_xg = 0.f;
#ifndef GGML_SYCL_RMS_BACK_FAST
                float c_xx = 0.f, c_xg = 0.f;
#endif
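                // Kahan compensated summation: c_* accumulates the rounding
                // error of each addition (y = term - c; t = sum + y;
                // c = (t - sum) - y; sum = t) so long rows lose less precision.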
                for (int64_t col = tid; col < D; col += WG) {
                    const float xv = x_row[col];
                    const float gv = g_row[col];
#ifdef GGML_SYCL_RMS_BACK_FAST
                    sum_xx += xv * xv;
                    sum_xg += xv * gv;
#else
                    float y1 = xv * xv - c_xx;
                    float t1 = sum_xx + y1;
                    c_xx = (t1 - sum_xx) - y1;
                    sum_xx = t1;

                    float y2 = xv * gv - c_xg;
                    float t2 = sum_xg + y2;
                    c_xg = (t2 - sum_xg) - y2;
                    sum_xg = t2;
#endif
                }

                // warp-level reduction; the .y() lane carries the compensation term
                sycl::float2 xx = sycl::float2(sum_xx,
#ifndef GGML_SYCL_RMS_BACK_FAST
                    c_xx
#else
                    0.f
#endif
                );
                sycl::float2 xg = sycl::float2(sum_xg,
#ifndef GGML_SYCL_RMS_BACK_FAST
                    c_xg
#else
                    0.f
#endif
                );
                xx = warp_reduce_sum(xx, item_ct1);
                xg = warp_reduce_sum(xg, item_ct1);

                // cross-warp reduction using local memory (single barrier)
                const auto sub_group = item_ct1.get_sub_group();
                const auto sg_id     = sub_group.get_group_linear_id();
                const auto wi_in_sg  = sub_group.get_local_linear_id();
                const int nthreads   = item_ct1.get_local_range(2);
                const int nwarps     = nthreads / WARP_SIZE;

                sycl::float2 xx_total = xx;
                sycl::float2 xg_total = xg;
                if (nwarps > 1) {
                    if (wi_in_sg == 0) {
                        l_xx[sg_id] = xx;
                        l_xg[sg_id] = xg;
                    }
                    item_ct1.barrier(sycl::access::fence_space::local_space);

                    if (sg_id == 0) {
                        const unsigned wi_u = wi_in_sg;
                        sycl::float2 xx_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xx[wi_u] : sycl::float2(0.f, 0.f);
                        sycl::float2 xg_first = (wi_u < static_cast<unsigned>(nwarps)) ? l_xg[wi_u] : sycl::float2(0.f, 0.f);
                        xx_total = warp_reduce_sum(xx_first, item_ct1);
                        xg_total = warp_reduce_sum(xg_first, item_ct1);
                    } else {
                        // other sub-groups keep their local totals; only the first
                        // sub-group's result is used, via the broadcast below
                        xx_total = xx;
                        xg_total = xg;
                    }
                }

                // compute inv_r and coeff once per row and broadcast them to the whole work-group
                float inv_r = 0.f;
                float coeff = 0.f;
                if (tid == 0) {
                    const float sum_xx_f  = xx_total.x() + xx_total.y();
                    const float sum_xdz_f = xg_total.x() + xg_total.y();
                    const float mean_eps  = sum_xx_f / (float) D + eps;
                    const float sum_eps   = sum_xx_f + eps * (float) D;
                    inv_r = sycl::rsqrt(mean_eps);
                    coeff = -sum_xdz_f / sum_eps;
                }
                inv_r = sycl::group_broadcast(item_ct1.get_group(), inv_r);
                coeff = sycl::group_broadcast(item_ct1.get_group(), coeff);

                for (int64_t col = tid; col < D; col += WG) {
                    d_row[col] = (g_row[col] + coeff * x_row[col]) * inv_r;
                }
            });
    });
}

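// l2_norm host wrapper: the kernel indexes rows as row * ne00, so src0 is
// treated as contiguous; no byte-stride conversion happens on this path.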
void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    dpct::queue_ptr main_stream = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    const int64_t ne00 = src0->ne[0];
    const int64_t nrows = ggml_nrows(src0);
    const float * src0_dd = static_cast<const float *>(src0->data);
    float * dst_dd = static_cast<float *>(dst->data);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device);
}