1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
3#ifdef cl_intel_subgroups
4#pragma OPENCL EXTENSION cl_intel_subgroups : enable
5#else
6#pragma OPENCL EXTENSION cl_khr_subgroups : enable
7#endif
8
9#ifdef cl_intel_required_subgroup_size
10#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
11#define INTEL_GPU 1
12#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
13#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
14#elif defined(cl_qcom_reqd_sub_group_size)
15#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
16#define ADRENO_GPU 1
17#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
18#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
19#endif
20
// Number of consecutive src1 rows handled by one work-group (see the use of
// get_group_id(1) in the kernel below).
#define N_F32_F32 4

/*
 * f32 x f32 matrix multiplication: every output element is the dot product of
 * one src0 row and one src1 row, both of length ne00.
 *
 * Work decomposition (as established by the indexing below):
 *   get_group_id(0) -> r0, the src0 row / output column index
 *   get_group_id(1) -> block of N_F32_F32 consecutive src1 rows (r1)
 *   get_group_id(2) -> im, flattened batch index; i12 = im % ne12, i13 = im / ne12
 *
 * The lanes of a sub-group split the dot product between them and combine
 * their partial sums with sub_group_reduce_add(); lane 0 stores the result at
 * dst[im*ne1*ne0 + r1*ne0 + r0].
 *
 * Parameters:
 *   src0, offset0    base pointer and byte offset of the first operand
 *   src1, offset1    base pointer and byte offset of the second operand
 *   dst,  offsetd    base pointer and byte offset of the output
 *   ne00             number of elements in each dot product
 *   ne11             number of src1 rows; rows >= ne11 are skipped
 *   ne12             size of the inner batch dimension used to split im
 *   nb01, nb02, nb03 byte strides of src0 (row, batch i12/r2, batch i13/r3)
 *   nb11, nb12, nb13 byte strides of src1 (row, batch i12, batch i13)
 *   ne0, ne1         output extents used to linearise the dst index
 *   r2, r3           divisors mapping src1 batch indices onto src0 batch
 *                    indices (presumably broadcast ratios - confirm with host)
 *   ne01, ne02, nb00, ne10, nb10 are not read by this kernel.
 *
 * NOTE(review): nothing here depends on the work-item's position inside the
 * work-group beyond its sub-group lane id, so every sub-group of a work-group
 * would compute and store the same values. Presumably the host launches one
 * sub-group per work-group - confirm against the host code.
 */
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mat_f32_f32(
        global char * src0,
        ulong offset0,
        global char * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne10,
        int ne11,
        int ne12,
        ulong nb10,
        ulong nb11,
        ulong nb12,
        ulong nb13,
        int ne0,
        int ne1,
        int r2,
        int r3
) {
    // Apply the byte offsets to the base pointers.
    src0 = src0 + offset0;
    src1 = src1 + offset1;
    dst  = (global float *)((global char *)dst + offsetd);

    const int r0 = get_group_id(0);
    const int rb = get_group_id(1)*N_F32_F32;
    const int im = get_group_id(2);

    const int i12 = im%ne12;
    const int i13 = im/ne12;

    // Stride by the size of the *running* sub-group rather than the maximum
    // size: if the sub-group is not full, striding by the maximum would
    // silently skip elements. For a full sub-group both values are equal.
    const int lane   = get_sub_group_local_id();
    const int stride = get_sub_group_size();

    const ulong offset_src0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
    global const float * x = (global const float *)(src0 + offset_src0);

    // Short rows are processed one element at a time; rows of 128 or more
    // elements are processed four elements at a time with a scalar tail.
    const bool use_vec4   = ne00 >= 128;
    const int  n_vec4     = ne00/4;
    const int  tail_begin = use_vec4 ? 4*n_vec4 : ne00;

    for (int row = 0; row < N_F32_F32; ++row) {
        const int r1 = rb + row;
        if (r1 >= ne11) {
            // r1 depends only on the group id, so the whole sub-group leaves
            // together and the reduction below stays uniform.
            break;
        }

        const ulong offset_src1 = r1*nb11 + i12*nb12 + i13*nb13;
        global const float * y = (global const float *)(src1 + offset_src1);

        float sumf = 0.0f;
        if (use_vec4) {
            for (int i = lane; i < n_vec4; i += stride) {
                // vload4 only requires float alignment. Dereferencing a
                // float4 pointer would require 16-byte alignment, which a row
                // is not guaranteed to have (e.g. when ne00 % 4 != 0 and rows
                // are contiguous).
                const float4 xv = vload4(i, x);
                const float4 yv = vload4(i, y);
                sumf += xv.s0 * yv.s0;
                sumf += xv.s1 * yv.s1;
                sumf += xv.s2 * yv.s2;
                sumf += xv.s3 * yv.s3;
            }
        } else {
            for (int i = lane; i < ne00; i += stride) {
                sumf += x[i] * y[i];
            }
        }

        float all_sum = sub_group_reduce_add(sumf);
        if (lane == 0) {
            // Elements left over after the last whole group of four; this
            // range is empty on the scalar path.
            for (int i = tail_begin; i < ne00; ++i) {
                all_sum += x[i] * y[i];
            }
            // Index in 64-bit arithmetic so large outputs do not overflow int.
            dst[(ulong)im*ne1*ne0 + (ulong)r1*ne0 + r0] = all_sum;
        }
    }
}