#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// When the Qualcomm sub-group-size extension is present, REQD_SUBGROUP_SIZE_128
// expands to an attribute requesting the "full" sub-group size; on every other
// platform it expands to nothing.
// NOTE(review): the macro name implies 128 lanes, but the attribute only asks
// for "full" - confirm the actual width against the Qualcomm documentation.
#if defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#else
#define REQD_SUBGROUP_SIZE_128
#endif

// Accumulator type and the number of scalar lanes it holds.
#define T_ACCUM float4
#define VEC_SIZE 4

// Work-group tile extents along K (output channels), NPQ (flattened
// batch * output rows * output columns) and CRS (flattened
// input channels * kernel rows * kernel columns).
#define BS_K 64
#define BS_NPQ 64
#define BS_CRS 16

// Per-work-item register tile extents along K and NPQ.
#define TS_K 4
#define TS_NPQ 8

// Work-items per work-group along each axis.
#define WG_K (BS_K / TS_K)
#define WG_NPQ (BS_NPQ / TS_NPQ)

// NPQ extents measured in VEC_SIZE-wide vectors.
#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)

// The derived sizes above use truncating integer division; refuse to build if a
// retuned configuration would silently leave part of a tile uncovered.
#if (BS_K % TS_K) != 0 || (BS_NPQ % TS_NPQ) != 0
#error "BS_K must be a multiple of TS_K and BS_NPQ a multiple of TS_NPQ"
#endif
#if (BS_NPQ % VEC_SIZE) != 0 || (TS_NPQ % VEC_SIZE) != 0
#error "BS_NPQ and TS_NPQ must both be multiples of VEC_SIZE"
#endif
25
// Returns how many blocks of block_size items are needed to cover work_size
// items, i.e. the ceiling of work_size / block_size.
//
// The result is formed from the quotient plus a remainder flag rather than
// (work_size + block_size - 1) / block_size, because that sum wraps around in
// unsigned arithmetic when work_size is close to the maximum uint value and
// would then yield a block count that is too small.
//
// block_size must be non-zero.
static inline uint splitWork(uint work_size, uint block_size) {
    const uint full_blocks = work_size / block_size;
    const uint has_partial = (work_size % block_size != 0) ? 1u : 0u;
    return full_blocks + has_partial;
}
29
// Tiled 2D convolution with half-precision weights and single-precision
// input/output, evaluated as a blocked matrix product
//
//     C[K x NPQ] = A[K x CRS] * B[CRS x NPQ]
//
// where K = Cout, CRS = Cin*KH*KW and NPQ = N*OH*OW. A is read from the weight
// buffer; B is gathered on the fly from the input buffer, with positions that
// fall outside the input image contributing zero.
//
// Each work-group produces a BS_K x BS_NPQ tile of C, walking CRS in steps of
// BS_CRS. Each work-item accumulates a TS_K x TS_NPQ sub-tile in registers.
//
// Parameters:
//   p_knl, off_knl   weight buffer and byte offset into it; read as half.
//   p_src, off_src   input buffer and byte offset into it; read as float.
//   p_dst, off_dst   output buffer and byte offset into it; written as float.
//   shared           local scratch holding BS_K*BS_CRS half values followed
//                    directly by BS_CRS*BS_NPQ_VEC float4 values.
//   Cout, Cin, N     output channels, input channels, batch size.
//   KW, KH           kernel width and height.
//   W, H             input width and height.
//   OW, OH           output width and height.
//   s0, s1           stride along width / height.
//   p0, p1           padding along width / height.
//   d0, d1           dilation along width / height.
//   nb01..nb03       weight strides for kernel row, input channel, output channel.
//   nb11..nb13       input strides for row, channel, batch.
//   nb1..nb3         output strides for row, channel, batch.
//
// All nb* strides multiply indices into typed pointers, so they are counted in
// elements, not bytes. The innermost (width) stride is taken to be 1.
//
// The cooperative tile loads step by WG_K * WG_NPQ work-items and the flat
// work-item id is lid_npq * WG_K + lid_k, so the kernel is written for a local
// size of (WG_K, WG_NPQ) with dimension 0 covering K and dimension 1 covering
// NPQ. NOTE(review): confirm against the host-side launch code.
REQD_SUBGROUP_SIZE_128
kernel void kernel_conv_2d(
        global void* p_knl,
        ulong off_knl,
        global void* p_src,
        ulong off_src,
        global void* p_dst,
        ulong off_dst,
        local void* shared,
        uint Cout, uint Cin, uint N,
        uint KW, uint KH, uint W, uint H, uint OW, uint OH,
        uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
        uint nb01, uint nb02, uint nb03,
        uint nb11, uint nb12, uint nb13,
        uint nb1, uint nb2, uint nb3
) {
    // Weights and input are only ever read; the output is only ever written.
    global const half*  knl_data = (global const half*) ((global char*)p_knl + off_knl);
    global const float* src_data = (global const float*)((global char*)p_src + off_src);
    global float*       dst_data = (global float*)      ((global char*)p_dst + off_dst);

    // Loop-invariant extents of the implicit matrix product.
    const uint KHW = KH * KW;
    const uint OHW = OH * OW;
    const uint K   = Cout;
    const uint CRS = Cin * KHW;
    const uint NPQ = N * OHW;

    const uint lid_k     = (uint)get_local_id(0);
    const uint lid_npq   = (uint)get_local_id(1);
    const uint tid       = lid_npq * WG_K + lid_k;
    const uint n_workers = WG_K * WG_NPQ;

    // Origin of this work-group's output tile.
    const uint offset_k   = (uint)get_group_id(0) * BS_K;
    const uint offset_npq = (uint)get_group_id(1) * BS_NPQ;

    // Local memory layout: weight tile first, input tile immediately after it.
    local half*   Ash = (local half*)shared;
    local float4* Bsh = (local float4*)&Ash[BS_K * BS_CRS];

    T_ACCUM acc[TS_K][TS_NPQ_VEC];
    for (uint k_r = 0; k_r < TS_K; ++k_r) {
        for (uint n_r = 0; n_r < TS_NPQ_VEC; ++n_r) {
            acc[k_r][n_r] = (T_ACCUM)(0.0f);
        }
    }

    const uint n_crs_blocks = splitWork(CRS, BS_CRS);

    for (uint crs_block = 0; crs_block < n_crs_blocks; ++crs_block) {
        const uint offset_crs = crs_block * BS_CRS;

        // Stage the BS_K x BS_CRS weight tile; entries outside K/CRS become zero.
        // The flat index i equals k_l * BS_CRS + crs_l, so it is also the slot in Ash.
        for (uint i = tid; i < BS_K * BS_CRS; i += n_workers) {
            const uint k_g   = offset_k + i / BS_CRS;
            const uint crs_g = offset_crs + i % BS_CRS;

            half a_val = (half)0.0f;
            if (k_g < K && crs_g < CRS) {
                // Unflatten crs_g into (input channel, kernel row, kernel column).
                const uint c   = crs_g / KHW;
                const uint rem = crs_g % KHW;
                const uint kh  = rem / KW;
                const uint kw  = rem % KW;
                a_val = knl_data[kw + kh * nb01 + c * nb02 + k_g * nb03];
            }
            Ash[i] = a_val;
        }

        // Stage the BS_CRS x BS_NPQ input tile, gathered lane by lane.
        // The flat index i equals crs_l * BS_NPQ_VEC + npq_l_vec, the slot in Bsh.
        for (uint i = tid; i < BS_CRS * BS_NPQ_VEC; i += n_workers) {
            const uint crs_g      = offset_crs + i / BS_NPQ_VEC;
            const uint npq_g_base = offset_npq + (i % BS_NPQ_VEC) * VEC_SIZE;

            float lanes[VEC_SIZE];
            for (uint v = 0; v < VEC_SIZE; ++v) {
                lanes[v] = 0.0f;
            }

            if (crs_g < CRS) {
                const uint c   = crs_g / KHW;
                const uint rem = crs_g % KHW;
                const uint kh  = rem / KW;
                const uint kw  = rem % KW;

                for (uint v = 0; v < VEC_SIZE; ++v) {
                    const uint npq_g = npq_g_base + v;
                    if (npq_g >= NPQ) {
                        continue;
                    }

                    // Unflatten npq_g into (batch, output row, output column).
                    const uint n  = npq_g / OHW;
                    const uint pq = npq_g % OHW;
                    const uint oh = pq / OW;
                    const uint ow = pq % OW;

                    // Input coordinate; negative means it lies in the padding.
                    const int h = (int)(oh * s1 + kh * d1 - p1);
                    const int w = (int)(ow * s0 + kw * d0 - p0);
                    if (h < 0 || (uint)h >= H || w < 0 || (uint)w >= W) {
                        continue;
                    }

                    lanes[v] = src_data[(uint)w + (uint)h * nb11 + c * nb12 + n * nb13];
                }
            }
            Bsh[i] = vload4(0, lanes);
        }

        // Both tiles must be fully written before any work-item reads them.
        barrier(CLK_LOCAL_MEM_FENCE);

        #pragma unroll
        for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
            // Convert this work-item's TS_K weights to float once per CRS step.
            float a_reg[TS_K];
            for (uint k_r = 0; k_r < TS_K; ++k_r) {
                a_reg[k_r] = convert_float(Ash[(lid_k * TS_K + k_r) * BS_CRS + crs_l]);
            }

            for (uint n_r = 0; n_r < TS_NPQ_VEC; ++n_r) {
                const float4 b_reg = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + n_r];
                for (uint k_r = 0; k_r < TS_K; ++k_r) {
                    // Explicit splat of the scalar weight across all four lanes.
                    acc[k_r][n_r] = mad((float4)(a_reg[k_r]), b_reg, acc[k_r][n_r]);
                }
            }
        }

        // Keep the tiles intact until every work-item has finished reading them.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Write the register tile back, skipping anything outside K or NPQ.
    for (uint k_r = 0; k_r < TS_K; ++k_r) {
        const uint k_g = offset_k + lid_k * TS_K + k_r;
        if (k_g >= K) {
            continue;
        }

        for (uint n_r = 0; n_r < TS_NPQ_VEC; ++n_r) {
            const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + n_r) * VEC_SIZE;
            const uint pq_base    = npq_g_base % OHW;
            const uint ow_base    = pq_base % OW;

            // A single vector store is used only when the output rows are densely
            // packed and all four lanes land in the same row and inside NPQ.
            const bool one_store = nb1 == OW
                                && ow_base + VEC_SIZE <= OW
                                && npq_g_base + VEC_SIZE <= NPQ;
            if (one_store) {
                const uint n  = npq_g_base / OHW;
                const uint oh = pq_base / OW;
                vstore4(acc[k_r][n_r], 0, &dst_data[ow_base + oh * nb1 + k_g * nb2 + n * nb3]);
                continue;
            }

            // Otherwise store lane by lane, recomputing the destination of each.
            float lanes[VEC_SIZE];
            vstore4(acc[k_r][n_r], 0, lanes);
            for (uint v = 0; v < VEC_SIZE; ++v) {
                const uint npq_g = npq_g_base + v;
                if (npq_g >= NPQ) {
                    continue;
                }
                const uint n  = npq_g / OHW;
                const uint pq = npq_g % OHW;
                const uint oh = pq / OW;
                const uint ow = pq % OW;
                dst_data[ow + oh * nb1 + k_g * nb2 + n * nb3] = lanes[v];
            }
        }
    }
}