// Scalar dot-product kernel: every work-item produces exactly one float.
//
// The work-item reads ne10 consecutive floats from src0 and ne10 consecutive
// floats from src1, multiplies them pairwise, accumulates the products in
// index order and stores the sum into dst.
//
//   src0/offset0    input buffer and byte offset applied before any indexing
//   src1/offset1    second input buffer and its byte offset
//   dst/offsetd     output buffer and its byte offset
//   nb00,nb01,nb02  byte strides in src0 applied to global ids 1, 0 and 2
//   ne10            number of float pairs accumulated per work-item
//   nb11            byte stride in src1 applied to global id 0
//   nb0,nb1,nb2     byte strides in dst applied to global ids 0, 1 and 2
//
// NOTE(review): presumably src1 holds per-row convolution weights and src0 a
// sliding input window (the name suggests an SSM conv step) - confirm on host.
kernel void kernel_ssm_conv_f32_f32(
    global char * src0,
    ulong         offset0,
    global char * src1,
    ulong         offset1,
    global char * dst,
    ulong         offsetd,
    ulong         nb00,
    ulong         nb01,
    ulong         nb02,
    int           ne10,
    ulong         nb11,
    ulong         nb0,
    ulong         nb1,
    ulong         nb2
){
    // Work-item coordinates along the three NDRange dimensions.
    const int gid0 = get_global_id(0);
    const int gid1 = get_global_id(1);
    const int gid2 = get_global_id(2);

    // Byte offsets of this work-item's data inside each (offset-adjusted) buffer.
    const ulong in_off  = gid0*nb01 + gid1*nb00 + gid2*nb02;
    const ulong wt_off  = gid0*nb11;
    const ulong out_off = gid0*nb0  + gid1*nb1  + gid2*nb2;

    global float * in_ptr  = (global float *) (src0 + offset0 + in_off);
    global float * wt_ptr  = (global float *) (src1 + offset1 + wt_off);
    global float * out_ptr = (global float *) (dst  + offsetd + out_off);

    const int count = ne10;

    // Accumulate in ascending index order.
    float acc = 0.0f;
    for (int k = 0; k < count; ++k) {
        acc += in_ptr[k] * wt_ptr[k];
    }

    *out_ptr = acc;
}
39
// Vectorised variant of kernel_ssm_conv_f32_f32: every work-item produces one
// float, the dot product of ne10 consecutive floats from src0 with ne10
// consecutive floats from src1, processed four elements at a time.
//
//   src0/offset0    input buffer and byte offset applied before any indexing
//   src1/offset1    second input buffer and its byte offset
//   dst/offsetd     output buffer and its byte offset
//   nb00,nb01,nb02  byte strides in src0 applied to global ids 1, 0 and 2
//   ne10            number of float pairs accumulated per work-item
//   nb11            byte stride in src1 applied to global id 0
//   nb0,nb1,nb2     byte strides in dst applied to global ids 0, 1 and 2
//
// NOTE(review): the host presumably selects this kernel only when ne10 is a
// multiple of 4; the scalar tail loop below keeps the result complete if that
// ever stops being true.
kernel void kernel_ssm_conv_f32_f32_4(
    global char * src0,
    ulong         offset0,
    global char * src1,
    ulong         offset1,
    global char * dst,
    ulong         offsetd,
    ulong         nb00,
    ulong         nb01,
    ulong         nb02,
    int           ne10,
    ulong         nb11,
    ulong         nb0,
    ulong         nb1,
    ulong         nb2
) {
    src0 = src0 + offset0;
    src1 = src1 + offset1;
    dst  = dst  + offsetd;

    int ir = get_global_id(0);
    int i2 = get_global_id(1);
    int i3 = get_global_id(2);

    int nc = ne10;

    // Keep scalar float pointers and fetch vectors with vload4. The src0
    // address includes i2*nb00, which is not guaranteed to be a multiple of
    // sizeof(float4); dereferencing a misaligned float4 pointer is undefined
    // in OpenCL C, whereas vload4 only requires float alignment.
    global const float * s = (global const float *) (src0 + ir*nb01 + i2*nb00 + i3*nb02);
    global const float * c = (global const float *) (src1 + ir*nb11);
    global float       * d = (global float       *) (dst  + ir*nb0  + i2*nb1  + i3*nb2);

    float sumf = 0.0f;

    // Main body: four products per iteration.
    const int nc4 = nc/4;
    for (int i0 = 0; i0 < nc4; ++i0) {
        sumf += dot(vload4(i0, s), vload4(i0, c));
    }

    // Tail: remaining 0..3 elements when ne10 is not a multiple of 4, so the
    // result matches the scalar kernel instead of silently dropping them.
    for (int i0 = nc4*4; i0 < nc; ++i0) {
        sumf += s[i0] * c[i0];
    }

    d[0] = sumf;
}