1#ifdef USE_FP16
2#pragma OPENCL EXTENSION cl_khr_fp16 : enable
3#define T_FLOAT half
4#define T_FLOAT4 half4
5#define VSTORE_T_FLOAT4(data, offset, p) vstore_half4_rte(data, offset, p)
6#else
7#define T_FLOAT float
8#define T_FLOAT4 float4
9#define VSTORE_T_FLOAT4(data, offset, p) vstore4(data, offset, p)
10#endif
11
12#if defined(cl_qcom_reqd_sub_group_size)
13#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
14#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
15#else
16#define REQD_SUBGROUP_SIZE_128
17#endif
18
19#define T_ACCUM float4
20#define VEC_SIZE 4
21
22#define BS_K 64
23#define BS_NPQ 64
24#define BS_CRS 16
25
26#define TS_K 4
27#define TS_NPQ 8
28
29#define WG_K (BS_K / TS_K)
30#define WG_NPQ (BS_NPQ / TS_NPQ)
31
32#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
33#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)
34
35static inline uint splitWork(uint work_size, uint block_size){
36 return (work_size + block_size - 1) / block_size;
37}
38
39REQD_SUBGROUP_SIZE_128
40kernel void kernel_conv_2d(
41 global void* p_knl,
42 ulong off_knl,
43 global void* p_src,
44 ulong off_src,
45 global void* p_dst,
46 ulong off_dst,
47 local void* shared,
48 uint Cout, uint Cin, uint N,
49 uint KW, uint KH, uint W, uint H, uint OW, uint OH,
50 uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
51 uint nb01, uint nb02, uint nb03,
52 uint nb11, uint nb12, uint nb13,
53 uint nb1, uint nb2, uint nb3
54) {
55 global T_FLOAT* knl_data = (global T_FLOAT*) ((global char*)p_knl + off_knl);
56 global T_FLOAT* src_data = (global T_FLOAT*) ((global char*)p_src + off_src);
57 global T_FLOAT* dst_data = (global T_FLOAT*) ((global char*)p_dst + off_dst);
58
59 const uint K = Cout;
60 const uint CRS = Cin*KH*KW;
61 const uint NPQ = N*OH*OW;
62
63 const uint lid_k = get_local_id(0);
64 const uint lid_npq = get_local_id(1);
65 const uint tid = lid_npq * WG_K + lid_k;
66
67 const uint B_idx_K = get_group_id(0);
68 const uint B_idx_NPQ = get_group_id(1);
69
70 const uint offset_k = B_idx_K * BS_K;
71 const uint offset_npq = B_idx_NPQ * BS_NPQ;
72
73 local T_FLOAT* Ash = (local T_FLOAT*)shared;
74 local T_FLOAT4* Bsh = (local T_FLOAT4*) &Ash[BS_K * BS_CRS];
75
76 T_ACCUM regC[TS_K][TS_NPQ_VEC];
77 for (int i = 0; i < TS_K; ++i) {
78 for (int j = 0; j < TS_NPQ_VEC; ++j) {
79 regC[i][j] = (T_ACCUM)(0.0f);
80 }
81 }
82
83 const uint NB_CRS = splitWork(CRS, BS_CRS);
84
85 for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
86 const uint offset_crs = B_idx_CRS * BS_CRS;
87
88 for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {
89 const uint k_l = i / BS_CRS;
90 const uint crs_l = i % BS_CRS;
91 const uint k_g = offset_k + k_l;
92 const uint crs_g = offset_crs + crs_l;
93
94 if (k_g < K && crs_g < CRS) {
95 const uint Cin_idx = crs_g / (KW*KH);
96 const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW;
97 const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW;
98 const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
99 Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];
100 } else {
101 Ash[k_l * BS_CRS + crs_l] = (T_FLOAT)0.0f;
102 }
103 }
104
105 for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {
106 const uint crs_l = i / BS_NPQ_VEC;
107 const uint npq_l_vec = i % BS_NPQ_VEC;
108 const uint crs_g = offset_crs + crs_l;
109
110 T_FLOAT4 val = (T_FLOAT4)(0.0f);
111 if (crs_g < CRS) {
112 const uint Cin_idx = crs_g / (KW * KH);
113 const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW;
114 const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW;
115 for (int v = 0; v < VEC_SIZE; ++v) {
116 const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;
117 if (npq_g < NPQ) {
118 const uint N_idx = npq_g / (OH * OW);
119 const uint pq_idx = npq_g % (OH * OW);
120 const uint OH_idx = pq_idx / OW;
121 const uint OW_idx = pq_idx % OW;
122 const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);
123 const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);
124
125 if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
126 const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
127 ((T_FLOAT*)&val)[v] = src_data[src_idx];
128 }
129 }
130 }
131 }
132 Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;
133 }
134
135 barrier(CLK_LOCAL_MEM_FENCE);
136
137 #pragma unroll
138 for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
139 T_FLOAT regA[TS_K];
140 for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
141 regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];
142 }
143
144 for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
145 T_FLOAT4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];
146 for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
147 regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), convert_float4(regB), regC[k_l_reg][npq_l_vec_reg]);
148 }
149 }
150 }
151 barrier(CLK_LOCAL_MEM_FENCE);
152 }
153
154 for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
155 const uint k_g = offset_k + lid_k * TS_K + k_l_reg;
156 if (k_g >= K) continue;
157
158 for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
159 const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;
160
161 const uint N_idx = npq_g_base / (OH * OW);
162 const uint pq_idx = npq_g_base % (OH * OW);
163 const uint OH_idx = pq_idx / OW;
164 const uint OW_idx = pq_idx % OW;
165
166 if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
167 const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
168 VSTORE_T_FLOAT4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);
169 } else {
170 T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];
171 for (int v = 0; v < VEC_SIZE; ++v) {
172 const uint npq_g = npq_g_base + v;
173 if (npq_g < NPQ) {
174 const uint N_idx_s = npq_g / (OH*OW);
175 const uint pq_idx_s = npq_g % (OH*OW);
176 const uint OH_idx_s = pq_idx_s / OW;
177 const uint OW_idx_s = pq_idx_s % OW;
178 const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
179 dst_data[dst_idx_s] = (T_FLOAT)(((float*)&res)[v]);
180 }
181 }
182 }
183 }
184 }
185}