1#include "softmax.hpp"
  2#include <cstdint>
  3#include <utility>
  4#include <cmath>
  5
  6
  7template <typename T> static __dpct_inline__ float t2f32(T val) {
  8    return (float) val;
  9}
 10
 11template <> float __dpct_inline__ t2f32<sycl::half>(sycl::half val) {
 12  return sycl::vec<sycl::half, 1>(val)
 13      .convert<float, sycl::rounding_mode::automatic>()[0];
 14}
 15
 16struct soft_max_params {
 17
 18    int64_t nheads;
 19    uint32_t n_head_log2;
 20    int64_t ncols;
 21    int64_t nrows_x;
 22    int64_t nrows_y;
 23    int64_t ne00;
 24    int64_t ne01;
 25    int64_t ne02;
 26    int64_t ne03;
 27    int64_t nb11;
 28    int64_t nb12;
 29    int64_t nb13;
 30
 31    int64_t ne12;
 32    int64_t ne13;
 33    float scale;
 34    float max_bias;
 35    float m0;
 36    float m1;
 37};
 38
 39// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
 40// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
 41#ifdef __clang__
 42#pragma clang diagnostic push
 43#pragma clang diagnostic ignored "-Wpass-failed"
 44#endif // __clang__
 45template <bool use_shared, int ncols_template, int block_size_template, typename T>
 46static void soft_max_f32(const float *         x,
 47                         const T *             mask,
 48                         const float *         sinks,
 49                         float *               dst,
 50                         const soft_max_params p,
 51                         uint8_t *             dpct_local) {
 52    auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
 53    const int ncols    = ncols_template == 0 ? p.ncols : ncols_template;
 54    const int block_size = block_size_template == 0
 55                               ? item_ct1.get_local_range(2)
 56                               : block_size_template;
 57    const int nthreads = block_size;
 58    const int nwarps = nthreads / WARP_SIZE;
 59    size_t nreduce = nwarps / WARP_SIZE;
 60
 61    const int tid = item_ct1.get_local_id(2);
 62
 63    const int64_t i03 = item_ct1.get_group(0);
 64    const int64_t i02 = item_ct1.get_group(1);
 65    const int64_t i01 = item_ct1.get_group(2);
 66
 67    //TODO: noncontigous inputs/outputs
 68    const int rowx = item_ct1.get_group(2) +
 69                     item_ct1.get_group(1) * item_ct1.get_group_range(2) +
 70                     item_ct1.get_group(0) * item_ct1.get_group_range(2) *
 71                         item_ct1.get_group_range(1);
 72
 73    const int64_t i11 = i01;
 74    const int64_t i12 = i02 % p.ne12;
 75    const int64_t i13 = i03 % p.ne13;
 76
 77    x    += int64_t(rowx)*ncols;
 78    mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
 79    dst  += int64_t(rowx)*ncols;
 80
 81    const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
 82    const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
 83
 84    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
 85
 86    float * buf_iw = (float *) dpct_local;
 87
 88    // shared memory buffer to cache values between iterations:
 89    float *vals = use_shared ? buf_iw + sycl::max(nwarps, WARP_SIZE) : dst;
 90    float max_val = sinks ? sinks[i02] : -INFINITY;
 91#pragma unroll
 92    for (int col0 = 0; col0 < ncols; col0 += block_size) {
 93        const int col = col0 + tid;
 94
 95        if (ncols_template == 0 && col >= ncols) {
 96            break;
 97        }
 98
 99        const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
100
101        vals[col] = val;
102        max_val   = sycl::max(max_val, val);
103    }
104    // find the max value in the block
105    max_val = warp_reduce_max(max_val);
106
107    if (block_size > WARP_SIZE) {
108        if (warp_id == 0) {
109            buf_iw[lane_id] = -INFINITY;
110        }
111        item_ct1.barrier();
112
113        if (lane_id == 0) {
114            buf_iw[warp_id] = max_val;
115        }
116        item_ct1.barrier();
117
118        max_val = buf_iw[lane_id];
119        max_val = warp_reduce_max(max_val);
120    }
121    float tmp = 0.0f; // partial sum
122
123#pragma unroll
124    for (int col0 = 0; col0 < ncols; col0 += block_size) {
125        const int col = col0 + tid;
126
127        if (ncols_template == 0 && col >= ncols) {
128            break;
129        }
130
131        const float val = sycl::native::exp(vals[col] - max_val);
132        tmp += val;
133        vals[col] = val;
134    }
135    // find the sum of exps in the block
136    tmp = warp_reduce_sum(tmp);
137    if (block_size > WARP_SIZE) {
138        item_ct1.barrier();
139        if (warp_id == 0) {
140            buf_iw[lane_id] = 0.0f;
141            for (size_t i = 1; i < nreduce; i += 1) {
142                buf_iw[lane_id + i * WARP_SIZE] = 0.f;
143            }
144        }
145        item_ct1.barrier();
146
147        if (lane_id == 0) {
148            buf_iw[warp_id] = tmp;
149        }
150        item_ct1.barrier();
151
152        tmp = buf_iw[lane_id];
153        for (size_t i = 1; i < nreduce; i += 1) {
154            tmp += buf_iw[lane_id + i * WARP_SIZE];
155        }
156        tmp = warp_reduce_sum(tmp);
157    }
158    if (sinks) {
159        tmp += sycl::native::exp(sinks[i02] - max_val);
160    }
161    const float inv_sum = 1.0f / tmp;
162
163#pragma unroll
164    for (int col0 = 0; col0 < ncols; col0 += block_size) {
165        const int col = col0 + tid;
166
167        if (ncols_template == 0 && col >= ncols) {
168            return;
169        }
170
171        dst[col] = vals[col] * inv_sum;
172    }
173}
174#ifdef __clang__
175#pragma clang diagnostic pop
176#endif // __clang__
177
178static void soft_max_back_f32(const float *grad, const float *dstf, float *dst,
179                              const int ncols, const float scale) {
180    auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
181    const int tid      = item_ct1.get_local_id(2);
182    const int rowx     = item_ct1.get_group(2);
183
184    grad += int64_t(rowx)*ncols;
185    dstf += int64_t(rowx)*ncols;
186    dst  += int64_t(rowx)*ncols;
187
188    float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients
189
190    for (int col = tid; col < ncols; col += WARP_SIZE) {
191        dgf_dot += dstf[col]*grad[col];
192    }
193
194    dgf_dot = warp_reduce_sum(dgf_dot);
195
196    for (int col = tid; col < ncols; col += WARP_SIZE) {
197        dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
198    }
199}
200
201template <int... Ns, typename T>
202static void launch_soft_max_kernels(const float *           x,
203                                    const T *               mask,
204                                    const float *           sinks,
205                                    float *                 dst,
206                                    const soft_max_params & p,
207                                    dpct::queue_ptr         stream,
208                                    dpct::dim3              block_dims,
209                                    dpct::dim3              block_nums,
210                                    size_t                  nbytes_shared)
211{
212    auto launch_kernel = [=](auto I) -> bool {
213        constexpr int ncols = decltype(I)::value;
214        constexpr int block = (ncols > 1024 ? 1024 : ncols);
215        if (p.ncols == ncols) {
216            stream->submit([&](sycl::handler &cgh) {
217                sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
218                    sycl::range<1>(nbytes_shared), cgh);
219
220                cgh.parallel_for(
221                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
222                    [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(
223                        WARP_SIZE)]] {
224                        soft_max_f32<true, ncols, block>(
225                            x, mask, sinks, dst, p,
226                            dpct_local_acc_ct1
227                                .get_multi_ptr<sycl::access::decorated::no>()
228                                .get());
229                        GGML_UNUSED(item_ct1);
230                    });
231            });
232            return true;
233        }
234        return false;
235    };
236
237    // unary fold over launch_kernel
238    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
239        return;
240    }
241
242    stream->submit([&](sycl::handler &cgh) {
243        sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
244            sycl::range<1>(nbytes_shared), cgh);
245
246        cgh.parallel_for(
247            sycl::nd_range<3>(block_nums * block_dims, block_dims),
248            [=](sycl::nd_item<3> item_ct1)
249                [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
250                    soft_max_f32<true, 0, 0>(
251                        x, mask, sinks, dst, p,
252                        dpct_local_acc_ct1
253                            .get_multi_ptr<sycl::access::decorated::no>()
254                            .get());
255                    GGML_UNUSED(item_ct1);
256                });
257    });
258}
259
260template <typename T>
261static void soft_max_f32_sycl(const float *x, const T *mask,
262                              const float *sinks, float *dst,
263                              const soft_max_params &params,
264                              dpct::queue_ptr stream, int device) {
265    int nth = WARP_SIZE;
266    int max_block_size = ggml_sycl_info().max_work_group_sizes[device];
267    const int64_t ncols_x = params.ncols;
268
269    while (nth < ncols_x && nth < max_block_size) nth *= 2;
270    if (nth>max_block_size) nth = max_block_size;
271
272    const dpct::dim3 block_dims(nth, 1, 1);
273    const dpct::dim3 block_nums(params.ne01, params.ne02, params.ne03);
274    const size_t nbytes_shared =
275        (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE) * sizeof(float);
276
277    const int id       = get_current_device_id();
278    const size_t smpbo = ggml_sycl_info().devices[id].smpbo;
279
280    if (nbytes_shared <= smpbo && ncols_x <= max_block_size) {
281        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(
282            x, mask, sinks, dst, params, stream, block_dims, block_nums,
283            nbytes_shared);
284    } else {
285        const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
286
287        stream->submit([&](sycl::handler &cgh) {
288            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
289                sycl::range<1>(nbytes_shared_low), cgh);
290
291            cgh.parallel_for(
292                sycl::nd_range<3>(block_nums * block_dims, block_dims),
293                [=](sycl::nd_item<3> item_ct1) {
294                    soft_max_f32<false, 0, 0>(
295                        x, mask, sinks, dst, params,
296                        dpct_local_acc_ct1
297                            .get_multi_ptr<sycl::access::decorated::no>()
298                            .get());
299                    GGML_UNUSED(item_ct1);
300                });
301        });
302    }
303}
304
305static void soft_max_back_f32_sycl(const float *   grad,
306                                   const float *   dstf,
307                                   float *         dst,
308                                   const int       ncols,
309                                   const int       nrows,
310                                   const float     scale,
311                                   dpct::queue_ptr stream) {
312    const dpct::dim3 block_dims(WARP_SIZE, 1, 1);
313    const dpct::dim3 block_nums(nrows, 1, 1);
314
315    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
316                         [=](sycl::nd_item<3> item_ct1) {
317                             soft_max_back_f32(grad, dstf, dst, ncols, scale);
318                             GGML_UNUSED(item_ct1);
319                         });
320}
321
322void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
323    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
324
325    const ggml_tensor * src0 = dst->src[0];
326    const ggml_tensor * src1 = dst->src[1];
327    const ggml_tensor * src2 = dst->src[2];
328
329    const float * src0_d = (const float *) src0->data;
330    const void  * src1_d = src1 ? (const void *) src1->data : nullptr;
331    const void  * src2_d = src2 ? (const void *) src2->data : nullptr;
332    float       *  dst_d = (float *) dst->data;
333
334    dpct::queue_ptr stream = ctx.stream();
335
336    GGML_ASSERT(src0->type == GGML_TYPE_F32);
337    GGML_ASSERT( dst->type == GGML_TYPE_F32);
338
339    // src1 contains mask and it is optional
340    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
341
342    const int64_t nrows_x = ggml_nrows(src0);
343    const int64_t nrows_y = src0->ne[1];
344
345    const int64_t ne00 = src0->ne[0];
346
347    float scale    = 1.0f;
348    float max_bias = 0.0f;
349
350    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
351    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
352
353    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
354
355    const int64_t nb11 = src1 ? src1->nb[1] : 1;
356    const int64_t nb12 = src1 ? src1->nb[2] : 1;
357    const int64_t nb13 = src1 ? src1->nb[3] : 1;
358
359    const int64_t ne12 = src1 ? src1->ne[2] : 1;
360    const int64_t ne13 = src1 ? src1->ne[3] : 1;
361
362    const uint32_t n_head      = src0->ne[2];
363    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
364
365    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
366    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
367
368
369    soft_max_params params = {};
370    params.nheads = src0->ne[2];
371    params.n_head_log2 = n_head_log2;
372    params.ncols = ne00;
373    params.nrows_x = nrows_x;
374    params.nrows_y = nrows_y;
375    params.ne00 = src0->ne[0];
376    params.ne01 = src0->ne[1];
377    params.ne02 = src0->ne[2];
378    params.ne03 = src0->ne[3];
379    params.nb11 = nb11;
380    params.nb12 = nb12;
381    params.nb13 = nb13;
382    params.ne12 = ne12;
383    params.ne13 = ne13;
384    params.scale = scale;
385    params.max_bias = max_bias;
386    params.m0 = m0;
387    params.m1 = m1;
388
389    if (use_f16) {
390        soft_max_f32_sycl(src0_d, (const sycl::half *)src1_d,
391                          (const float *)src2_d, dst_d, params, stream,
392                          ctx.device);
393    } else {
394        soft_max_f32_sycl(src0_d, (const float *)src1_d, (const float *)src2_d,
395                          dst_d, params, stream, ctx.device);
396    }
397}
398
399void ggml_sycl_op_soft_max_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
400    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
401    const ggml_tensor * src0 = dst->src[0]; // grad
402    const ggml_tensor * src1 = dst->src[1]; // forward pass output
403
404    const float * src0_d = (const float *) src0->data;
405    const float * src1_d = (const float *) src1->data;
406    float       * dst_d  = (float       *) dst->data;
407
408    dpct::queue_ptr stream = ctx.stream();
409
410    GGML_ASSERT(src0->type == GGML_TYPE_F32);
411    GGML_ASSERT(src1->type == GGML_TYPE_F32);
412    GGML_ASSERT( dst->type == GGML_TYPE_F32);
413
414    const int64_t ncols = src0->ne[0];
415    const int64_t nrows = ggml_nrows(src0);
416
417    float scale    = 1.0f;
418    float max_bias = 0.0f;
419
420    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
421    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
422
423    GGML_ASSERT(max_bias == 0.0f);
424
425    soft_max_back_f32_sycl(src0_d, src1_d, dst_d, ncols, nrows, scale, stream);
426}