// Scalar f32 "ssm_conv" kernel: each work-item produces exactly one float of dst.
//
// For work-item (row, col, plane) = get_global_id(0..2):
//   window = src0 + offset0 + row*nb01 + col*nb00 + plane*nb02   (bytes)
//   taps   = src1 + offset1 + row*nb11                           (bytes)
//   dst[offsetd + row*nb0 + col*nb1 + plane*nb2] = sum_{k < ne10} window[k] * taps[k]
//
// window[k] and taps[k] are read as consecutive floats. The window therefore
// slides along src0 dim 0 as `col` changes. This presumably assumes
// nb00 == sizeof(float), i.e. src0 contiguous along dim 0 — TODO confirm on the host side.
// NOTE(review): the name suggests the short convolution of an SSM (Mamba-style)
// layer; that is inferred from the name only.
kernel void kernel_ssm_conv_f32_f32(
        global char * src0,
        ulong offset0,
        global char * src1,
        ulong offset1,
        global char * dst,
        ulong offsetd,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        int ne10,
        ulong nb11,
        ulong nb0,
        ulong nb1,
        ulong nb2
) {
    const int row   = get_global_id(0);
    const int col   = get_global_id(1);
    const int plane = get_global_id(2);

    // All strides/offsets are in bytes, so the address math is done on char pointers.
    global const float * window = (global const float *) (src0 + offset0 + row*nb01 + col*nb00 + plane*nb02);
    global const float * taps   = (global const float *) (src1 + offset1 + row*nb11);
    global float       * out    = (global float *)       (dst  + offsetd + row*nb0  + col*nb1  + plane*nb2);

    float acc = 0.0f;
    for (int k = 0; k < ne10; k++) {
        acc += window[k] * taps[k];
    }

    *out = acc;
}
39
// Vectorised variant of kernel_ssm_conv_f32_f32: each work-item produces one float of dst.
//
// For work-item (ir, i2, i3) = get_global_id(0..2):
//   s = src0 + offset0 + ir*nb01 + i2*nb00 + i3*nb02   (bytes)
//   c = src1 + offset1 + ir*nb11                        (bytes)
//   dst[offsetd + ir*nb0 + i2*nb1 + i3*nb2] = sum_{i0 < ne10} s[i0] * c[i0]
//
// s[i0] / c[i0] are consecutive floats. This presumably assumes nb00 == sizeof(float)
// (src0 contiguous along dim 0) — TODO confirm on the host side.
//
// The bulk of the sum is done four floats at a time. Any remainder
// (ne10 % 4 != 0) is handled by a scalar tail loop, so the result matches the
// scalar kernel's sum of products. The summation order differs, so it may not
// be bit-identical.
kernel void kernel_ssm_conv_f32_f32_4(
        global char * src0,
        ulong offset0,
        global char * src1,
        ulong offset1,
        global char * dst,
        ulong offsetd,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        int ne10,
        ulong nb11,
        ulong nb0,
        ulong nb1,
        ulong nb2
) {
    src0 = src0 + offset0;
    src1 = src1 + offset1;
    dst  = dst  + offsetd;

    int ir = get_global_id(0);
    int i2 = get_global_id(1);
    int i3 = get_global_id(2);

    int nc  = ne10;
    int nc4 = nc / 4;

    // Keep these as scalar float pointers. Nothing guarantees 16-byte alignment
    // here: the window start advances by i2*nb00 bytes (one float per step when
    // nb00 == sizeof(float)), and offset0/offset1 are arbitrary byte offsets.
    // Dereferencing a misaligned float4* is undefined behaviour in OpenCL C.
    global const float * s = (global const float *) (src0 + ir*nb01 + i2*nb00 + i3*nb02);
    global const float * c = (global const float *) (src1 + ir*nb11);
    global float       * d = (global float *)       (dst  + ir*nb0  + i2*nb1  + i3*nb2);

    float sumf = 0.0f;

    // vload4(i0, p) reads p[4*i0 .. 4*i0+3] and only requires float alignment.
    for (int i0 = 0; i0 < nc4; ++i0) {
        sumf += dot(vload4(i0, s), vload4(i0, c));
    }

    // Remaining nc % 4 taps. The original silently dropped them. This loop is a
    // no-op when ne10 is a multiple of 4.
    for (int i0 = nc4*4; i0 < nc; ++i0) {
        sumf += s[i0] * c[i0];
    }

    d[0] = sumf;
}