1#include "ssm_conv.hpp"
  2#include "common.hpp"
  3
  4#include <cstdio>
  5
  6using namespace sycl;
  7
  8static void kernel_ssm_conv(
  9    queue &q,
 10    const float *src_data,
 11    const float *weights,
 12    float *dst_data,
 13    int d_conv,
 14    int d_inner,
 15    int n_t,
 16    int n_s,
 17    int ncs __attribute__((unused)),
 18    int src_stride_inner,
 19    int src_stride_seq,
 20    int dst_stride_token,
 21    int dst_stride_seq
 22) {
 23    const size_t total_work = static_cast<size_t>(d_inner) * static_cast<size_t>(n_t) * static_cast<size_t>(n_s);
 24    const size_t work_group_size = 256;
 25    const size_t num_work_groups = (total_work + work_group_size - 1) / work_group_size;
 26
 27    const range<1> global_range(num_work_groups * work_group_size);
 28    const range<1> local_range(work_group_size);
 29
 30    q.submit([&](handler &h) {
 31        h.parallel_for(
 32            nd_range<1>(global_range, local_range),
 33            [=](nd_item<1> item) {
 34                const size_t idx = item.get_global_id(0);
 35                if (idx >= total_work) {
 36                    return;
 37                }
 38
 39                const int channel = static_cast<int>(idx % d_inner);
 40                const int token   = static_cast<int>((idx / d_inner) % n_t);
 41                const int seq     = static_cast<int>(idx / (static_cast<size_t>(d_inner) * static_cast<size_t>(n_t)));
 42
 43                const float *s = src_data
 44                    + static_cast<size_t>(seq) * static_cast<size_t>(src_stride_seq)
 45                    + static_cast<size_t>(channel) * static_cast<size_t>(src_stride_inner)
 46                    + static_cast<size_t>(token);
 47
 48                const float *c = weights + static_cast<size_t>(channel) * static_cast<size_t>(d_conv);
 49
 50                float sumf = 0.0f;
 51                for (int i0 = 0; i0 < d_conv; ++i0) {
 52                    sumf += s[i0] * c[i0];
 53                }
 54
 55                const size_t dst_idx =
 56                    static_cast<size_t>(seq) * static_cast<size_t>(dst_stride_seq) +
 57                    static_cast<size_t>(token) * static_cast<size_t>(dst_stride_token) +
 58                    static_cast<size_t>(channel);
 59
 60                dst_data[dst_idx] = sumf;
 61            }
 62        );
 63    });
 64}
 65
 66void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 67    ggml_tensor * src0 = dst->src[0];
 68    ggml_tensor * src1 = dst->src[1];
 69
 70    GGML_ASSERT(src0->type == GGML_TYPE_F32);
 71    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 72    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
 73
 74    const int d_conv   = src1->ne[0];
 75    const int ncs      = src0->ne[0];
 76    const int d_inner  = src0->ne[1];
 77    const int n_t      = dst->ne[1];
 78    const int n_s      = dst->ne[2];
 79
 80    GGML_ASSERT(src0->ne[0] == d_conv - 1 + n_t);
 81    GGML_ASSERT(src0->ne[1] == d_inner);
 82    GGML_ASSERT(src1->ne[1] == d_inner);
 83
 84    GGML_ASSERT(dst->ne[0] == d_inner);
 85    GGML_ASSERT(dst->ne[1] == n_t);
 86    GGML_ASSERT(dst->ne[2] == n_s);
 87
 88    GGML_ASSERT(src0->nb[0] == sizeof(float));
 89    GGML_ASSERT(src1->nb[0] == sizeof(float));
 90
 91    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
 92
 93    const int src_stride_inner = ncs;
 94    const int src_stride_seq   = ncs * d_inner;
 95    const int dst_stride_token = d_inner;
 96    const int dst_stride_seq   = d_inner * n_t;
 97
 98    try {
 99        queue *q = ctx.stream();
100
101        const float *src_data = static_cast<const float *>(src0->data);
102        const float *weights  = static_cast<const float *>(src1->data);
103        float *dst_data       = static_cast<float *>(dst->data);
104
105        GGML_ASSERT(src_data && weights && dst_data);
106
107        kernel_ssm_conv(
108            *q,
109            src_data,
110            weights,
111            dst_data,
112            d_conv,
113            d_inner,
114            n_t,
115            n_s,
116            ncs,
117            src_stride_inner,
118            src_stride_seq,
119            dst_stride_token,
120            dst_stride_seq
121        );
122
123    } catch (const std::exception &e) {
124        std::fprintf(stderr, "[SYCL-SSM_CONV] ERROR: %s\n", e.what());
125        throw;
126    }
127}