// Element type selection. T_FLOAT / T_FLOAT4 are the element types used for the
// kernel's global buffers and local-memory tiles.
//   - With USE_FP16: half / half4 (requires cl_khr_fp16). VSTORE_T_FLOAT4
//     stores a float4 value into a half buffer via vstore_half4_rte, i.e.
//     converting with round-to-nearest-even.
//   - Otherwise: float / float4, and VSTORE_T_FLOAT4 is a plain vstore4.
#ifdef USE_FP16
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define T_FLOAT half
#define T_FLOAT4 half4
#define VSTORE_T_FLOAT4(data, offset, p) vstore_half4_rte(data, offset, p)
#else
#define T_FLOAT float
#define T_FLOAT4 float4
#define VSTORE_T_FLOAT4(data, offset, p) vstore4(data, offset, p)
#endif

// When the Qualcomm cl_qcom_reqd_sub_group_size extension is available, request
// the "full" sub-group size for the kernel. On other devices the macro expands
// to nothing.
// NOTE(review): the "_128" suffix presumes a full sub-group is 128 work-items
// wide, which would match the 16x8 work-group implied by the tiling constants.
// Confirm this per target device.
#if defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#else
#define REQD_SUBGROUP_SIZE_128
#endif
 18
 19#define T_ACCUM float4
 20#define VEC_SIZE 4
 21
 22#define BS_K 64
 23#define BS_NPQ 64
 24#define BS_CRS 16
 25
 26#define TS_K 4
 27#define TS_NPQ 8
 28
 29#define WG_K (BS_K / TS_K)
 30#define WG_NPQ (BS_NPQ / TS_NPQ)
 31
 32#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
 33#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)
 34
 35static inline uint splitWork(uint work_size, uint block_size){
 36    return (work_size + block_size - 1) / block_size;
 37}
 38
 39REQD_SUBGROUP_SIZE_128
 40kernel void kernel_conv_2d(
 41    global void* p_knl,
 42    ulong off_knl,
 43    global void* p_src,
 44    ulong off_src,
 45    global void* p_dst,
 46    ulong off_dst,
 47    local void* shared,
 48    uint Cout, uint Cin, uint N,
 49    uint KW, uint KH, uint W, uint H, uint OW, uint OH,
 50    uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
 51    uint nb01, uint nb02, uint nb03,
 52    uint nb11, uint nb12, uint nb13,
 53    uint nb1, uint nb2, uint nb3
 54) {
 55    global T_FLOAT* knl_data = (global T_FLOAT*) ((global char*)p_knl + off_knl);
 56    global T_FLOAT* src_data = (global T_FLOAT*) ((global char*)p_src + off_src);
 57    global T_FLOAT* dst_data = (global T_FLOAT*) ((global char*)p_dst + off_dst);
 58
 59    const uint K = Cout;
 60    const uint CRS = Cin*KH*KW;
 61    const uint NPQ = N*OH*OW;
 62
 63    const uint lid_k = get_local_id(0);
 64    const uint lid_npq = get_local_id(1);
 65    const uint tid = lid_npq * WG_K + lid_k;
 66
 67    const uint B_idx_K = get_group_id(0);
 68    const uint B_idx_NPQ = get_group_id(1);
 69
 70    const uint offset_k = B_idx_K * BS_K;
 71    const uint offset_npq = B_idx_NPQ * BS_NPQ;
 72
 73    local T_FLOAT* Ash = (local T_FLOAT*)shared;
 74    local T_FLOAT4* Bsh = (local T_FLOAT4*) &Ash[BS_K * BS_CRS];
 75
 76    T_ACCUM regC[TS_K][TS_NPQ_VEC];
 77    for (int i = 0; i < TS_K; ++i) {
 78        for (int j = 0; j < TS_NPQ_VEC; ++j) {
 79            regC[i][j] = (T_ACCUM)(0.0f);
 80        }
 81    }
 82
 83    const uint NB_CRS = splitWork(CRS, BS_CRS);
 84
 85    for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
 86        const uint offset_crs = B_idx_CRS * BS_CRS;
 87
 88        for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {
 89            const uint k_l = i / BS_CRS;
 90            const uint crs_l = i % BS_CRS;
 91            const uint k_g = offset_k + k_l;
 92            const uint crs_g = offset_crs + crs_l;
 93
 94            if (k_g < K && crs_g < CRS) {
 95                const uint Cin_idx = crs_g / (KW*KH);
 96                const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW;
 97                const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW;
 98                const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
 99                Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];
100            } else {
101                Ash[k_l * BS_CRS + crs_l] = (T_FLOAT)0.0f;
102            }
103        }
104
105        for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {
106            const uint crs_l = i / BS_NPQ_VEC;
107            const uint npq_l_vec = i % BS_NPQ_VEC;
108            const uint crs_g = offset_crs + crs_l;
109
110            T_FLOAT4 val = (T_FLOAT4)(0.0f);
111            if (crs_g < CRS) {
112                const uint Cin_idx = crs_g / (KW * KH);
113                const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW;
114                const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW;
115                for (int v = 0; v < VEC_SIZE; ++v) {
116                    const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;
117                    if (npq_g < NPQ) {
118                        const uint N_idx = npq_g / (OH * OW);
119                        const uint pq_idx = npq_g % (OH * OW);
120                        const uint OH_idx = pq_idx / OW;
121                        const uint OW_idx = pq_idx % OW;
122                        const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);
123                        const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);
124
125                        if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
126                            const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
127                            ((T_FLOAT*)&val)[v] = src_data[src_idx];
128                        }
129                    }
130                }
131            }
132            Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;
133        }
134
135        barrier(CLK_LOCAL_MEM_FENCE);
136
137        #pragma unroll
138        for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
139            T_FLOAT regA[TS_K];
140            for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
141                regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];
142            }
143
144            for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
145                T_FLOAT4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];
146                for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
147                    regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), convert_float4(regB), regC[k_l_reg][npq_l_vec_reg]);
148                }
149            }
150        }
151        barrier(CLK_LOCAL_MEM_FENCE);
152    }
153
154    for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
155        const uint k_g = offset_k + lid_k * TS_K + k_l_reg;
156        if (k_g >= K) continue;
157
158        for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
159            const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;
160
161            const uint N_idx = npq_g_base / (OH * OW);
162            const uint pq_idx = npq_g_base % (OH * OW);
163            const uint OH_idx = pq_idx / OW;
164            const uint OW_idx = pq_idx % OW;
165
166            if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
167                const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
168                VSTORE_T_FLOAT4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);
169            } else {
170                T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];
171                for (int v = 0; v < VEC_SIZE; ++v) {
172                    const uint npq_g = npq_g_base + v;
173                    if (npq_g < NPQ) {
174                        const uint N_idx_s = npq_g / (OH*OW);
175                        const uint pq_idx_s = npq_g % (OH*OW);
176                        const uint OH_idx_s = pq_idx_s / OW;
177                        const uint OW_idx_s = pq_idx_s % OW;
178                        const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
179                        dst_data[dst_idx_s] = (T_FLOAT)(((float*)&res)[v]);
180                    }
181                }
182            }
183        }
184    }
185}