// The weights are read as half, so fp16 support is required.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// When the Qualcomm sub-group extension is available, request its "full"
// sub-group size for the kernel; on other platforms the attribute expands to
// nothing. NOTE(review): the macro name suggests a 128-wide sub-group, which
// matches WG_K * WG_NPQ below; confirm against the extension spec.
#if defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#else
#define REQD_SUBGROUP_SIZE_128
#endif

// Accumulator type and vector width along the NPQ (output pixel) dimension.
// T_ACCUM must be a VEC_SIZE-wide float vector: the kernel fills it from
// float4 local tiles and writes it back with vstore4.
#define T_ACCUM float4
#define VEC_SIZE 4

// Work-group tile: BS_K output channels x BS_NPQ output pixels, with the
// reduction dimension (CRS = Cin*KH*KW) consumed BS_CRS taps at a time.
#define BS_K 64
#define BS_NPQ 64
#define BS_CRS 16

// Per-work-item register tile: TS_K output channels x TS_NPQ output pixels.
#define TS_K 4
#define TS_NPQ 8

// Work-group shape implied by the tiling: (WG_K, WG_NPQ) work-items.
#define WG_K (BS_K / TS_K)
#define WG_NPQ (BS_NPQ / TS_NPQ)

// NPQ tile extents measured in VEC_SIZE-wide vectors.
#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE)
#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE)

// The derived sizes above use truncating division. A configuration that does
// not divide evenly would compile but leave parts of each tile unloaded or
// uncomputed, so reject it at compile time instead.
#if (BS_K % TS_K) != 0 || (BS_NPQ % TS_NPQ) != 0
#error "BS_K and BS_NPQ must be multiples of TS_K and TS_NPQ"
#endif
#if (BS_NPQ % VEC_SIZE) != 0 || (TS_NPQ % VEC_SIZE) != 0
#error "BS_NPQ and TS_NPQ must be multiples of VEC_SIZE"
#endif
// The float4 B tile is placed directly after the half A tile in local memory.
// The A tile's byte size (BS_K * BS_CRS * sizeof(half)) must therefore be a
// multiple of 16 bytes, or the B tile would be misaligned.
#if ((BS_K * BS_CRS * 2) % 16) != 0
#error "BS_K * BS_CRS half elements must keep the float4 B tile 16-byte aligned"
#endif
 25
 26static inline uint splitWork(uint work_size, uint block_size){
 27    return (work_size + block_size - 1) / block_size;
 28}
 29
// Direct 2-D convolution evaluated as an implicit GEMM:
//   dst[k][npq] = sum over crs of knl[k][crs] * patch[crs][npq]
// with K = Cout output channels, CRS = Cin*KH*KW reduction taps and
// NPQ = N*OH*OW output pixels. Patch values are gathered on the fly from src
// (no im2col buffer is materialized). Taps landing outside the H x W input
// read as zero, which is how the padding p0/p1 takes effect.
//
// Tiling: work-group (get_group_id(0), get_group_id(1)) owns a BS_K x BS_NPQ
// output tile, and each work-item accumulates a TS_K x TS_NPQ sub-tile in
// registers. The CRS dimension is walked in BS_CRS-wide slices that are
// staged cooperatively through local memory.
//
// Launch assumptions (not checked here; NOTE(review): confirm on the host):
//   * the local size is (WG_K, WG_NPQ). The cooperative loads stride by
//     WG_K * WG_NPQ and the compute loop indexes Ash with lid_k * TS_K, so
//     any other local size breaks the tile indexing;
//   * dimension 0 presumably spans ceil(K / BS_K) work-groups and
//     dimension 1 ceil(NPQ / BS_NPQ) work-groups. Out-of-range channels and
//     pixels are masked on store.
//
// Parameters:
//   p_knl, off_knl    weight buffer and byte offset; elements read as half.
//   p_src, off_src    input buffer and byte offset; elements read as float.
//   p_dst, off_dst    output buffer and byte offset; elements written as float.
//   shared            local scratch holding Ash (BS_K*BS_CRS halves) followed
//                     by Bsh (BS_CRS*BS_NPQ_VEC float4s). It must be at least
//                     BS_K*BS_CRS*sizeof(half) + BS_CRS*BS_NPQ*sizeof(float)
//                     bytes.
//   Cout, Cin, N      output channels, input channels, and the outermost
//                     (image) index count shared by src and dst.
//   KW, KH            kernel width / height.
//   W, H              input width / height.
//   OW, OH            output width / height.
//   s0/s1, p0/p1, d0/d1
//                     stride, padding and dilation along W / H.
//   nb01..nb03        weight strides in elements for KH, Cin, Cout (KW is
//                     indexed with stride 1).
//   nb11..nb13        input strides in elements for H, Cin, N (W stride 1).
//   nb1..nb3          output strides in elements for OH, Cout, N (OW stride 1).
REQD_SUBGROUP_SIZE_128
kernel void kernel_conv_2d(
    global void* p_knl,
    ulong off_knl,
    global void* p_src,
    ulong off_src,
    global void* p_dst,
    ulong off_dst,
    local void* shared,
    uint Cout, uint Cin, uint N,
    uint KW, uint KH, uint W, uint H, uint OW, uint OH,
    uint s0, uint s1, uint p0, uint p1, uint d0, uint d1,
    uint nb01, uint nb02, uint nb03,
    uint nb11, uint nb12, uint nb13,
    uint nb1, uint nb2, uint nb3
) {
    // The offsets are in bytes, so apply them on a char pointer before retyping.
    global half* knl_data = (global half*) ((global char*)p_knl + off_knl);
    global float* src_data = (global float*) ((global char*)p_src + off_src);
    global float* dst_data = (global float*) ((global char*)p_dst + off_dst);

    // GEMM dimensions of the convolution.
    const uint K = Cout;
    const uint CRS = Cin*KH*KW;
    const uint NPQ = N*OH*OW;

    // lid_k selects this item's TS_K output channels and lid_npq its TS_NPQ
    // output pixels. tid is the flattened id used for the cooperative loads.
    const uint lid_k = get_local_id(0);
    const uint lid_npq = get_local_id(1);
    const uint tid = lid_npq * WG_K + lid_k;

    const uint B_idx_K = get_group_id(0);
    const uint B_idx_NPQ = get_group_id(1);

    // First output channel / output pixel covered by this work-group.
    const uint offset_k = B_idx_K * BS_K;
    const uint offset_npq = B_idx_NPQ * BS_NPQ;

    // Local tiles: Ash is [BS_K][BS_CRS] half and Bsh is
    // [BS_CRS][BS_NPQ_VEC] float4, placed directly after Ash. Ash spans
    // BS_K*BS_CRS*2 bytes (2048 with the default sizes), a multiple of 16, so
    // Bsh keeps float4 alignment relative to `shared`.
    local half* Ash = (local half*)shared;
    local float4* Bsh = (local float4*) &Ash[BS_K * BS_CRS];

    // Register accumulators: TS_K channels x TS_NPQ_VEC vectors of VEC_SIZE pixels.
    T_ACCUM regC[TS_K][TS_NPQ_VEC];
    for (int i = 0; i < TS_K; ++i) {
        for (int j = 0; j < TS_NPQ_VEC; ++j) {
            regC[i][j] = (T_ACCUM)(0.0f);
        }
    }

    // Number of BS_CRS-wide slices of the reduction dimension (the last one
    // may be partial and is zero-filled below).
    const uint NB_CRS = splitWork(CRS, BS_CRS);

    for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) {
        const uint offset_crs = B_idx_CRS * BS_CRS;

        // Stage the weight slice into Ash. Entries whose channel is >= K or
        // whose tap is >= CRS are zero-filled, so they contribute nothing to
        // the products below.
        for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) {
            const uint k_l = i / BS_CRS;
            const uint crs_l = i % BS_CRS;
            const uint k_g = offset_k + k_l;
            const uint crs_g = offset_crs + crs_l;

            if (k_g < K && crs_g < CRS) {
                // Split the flat tap index into (Cin, KH, KW) coordinates.
                const uint Cin_idx = crs_g / (KW*KH);
                const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW;
                const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW;
                const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03;
                Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx];
            } else {
                Ash[k_l * BS_CRS + crs_l] = (half)0.0f;
            }
        }

        // Stage the input patch slice into Bsh. For each tap and each group
        // of VEC_SIZE consecutive output pixels, gather the input sample that
        // tap reads. Pixels >= NPQ, taps >= CRS, and samples outside the
        // H x W image (padding) stay zero.
        for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) {
            const uint crs_l = i / BS_NPQ_VEC;
            const uint npq_l_vec = i % BS_NPQ_VEC;
            const uint crs_g = offset_crs + crs_l;

            float4 val = (float4)(0.0f);
            if (crs_g < CRS) {
                const uint Cin_idx = crs_g / (KW * KH);
                const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW;
                const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW;
                for (int v = 0; v < VEC_SIZE; ++v) {
                    const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v;
                    if (npq_g < NPQ) {
                        // Split the flat pixel index into (N, OH, OW) coordinates.
                        const uint N_idx = npq_g / (OH * OW);
                        const uint pq_idx = npq_g % (OH * OW);
                        const uint OH_idx = pq_idx / OW;
                        const uint OW_idx = pq_idx % OW;
                        // Input coordinates read by this tap. The unsigned
                        // expression wraps when the tap falls in the top/left
                        // padding. The int cast is relied on to turn that into
                        // a negative value, which the bounds check rejects.
                        const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1);
                        const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0);

                        if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) {
                            const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13;
                            // Write lane v of the float4 through a scalar pointer.
                            ((float*)&val)[v] = src_data[src_idx];
                        }
                    }
                }
            }
            Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val;
        }

        // Every item must finish writing Ash/Bsh before any item reads them.
        barrier(CLK_LOCAL_MEM_FENCE);

        // Multiply the staged tiles. For each tap, load this item's TS_K
        // weights, then for each of its TS_NPQ_VEC pixel vectors accumulate
        // weight * input into regC (weights widened from half to float).
        #pragma unroll
        for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) {
            half regA[TS_K];
            for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
                regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l];
            }

            for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
                float4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg];
                for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
                    regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), regB, regC[k_l_reg][npq_l_vec_reg]);
                }
            }
        }
        // Keep faster items from overwriting Ash/Bsh for the next slice while
        // slower items are still reading the current one.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Write this item's TS_K x TS_NPQ sub-tile to dst, skipping channels >= K
    // and pixels >= NPQ.
    for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) {
        const uint k_g = offset_k + lid_k * TS_K + k_l_reg;
        if (k_g >= K) continue;

        for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) {
            // Flat index of the first of the VEC_SIZE pixels in this vector.
            const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE;

            const uint N_idx = npq_g_base / (OH * OW);
            const uint pq_idx = npq_g_base % (OH * OW);
            const uint OH_idx = pq_idx / OW;
            const uint OW_idx = pq_idx % OW;

            // Fast path: all VEC_SIZE pixels are < NPQ and share one output
            // row (OW_idx + VEC_SIZE <= OW), so they occupy consecutive dst
            // elements and a single vstore4 writes them.
            // NOTE(review): the nb1 == OW condition looks stricter than
            // needed for that contiguity; confirm whether it is intentional.
            if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) {
                const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3;
                vstore4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]);
            } else {
                // Slow path: the vector may straddle a row or image boundary
                // or run past NPQ, so store lane by lane and recompute each
                // pixel's coordinates.
                T_ACCUM res = regC[k_l_reg][npq_l_vec_reg];
                for (int v = 0; v < VEC_SIZE; ++v) {
                    const uint npq_g = npq_g_base + v;
                    if (npq_g < NPQ) {
                        const uint N_idx_s = npq_g / (OH*OW);
                        const uint pq_idx_s = npq_g % (OH*OW);
                        const uint OH_idx_s = pq_idx_s / OW;
                        const uint OW_idx_s = pq_idx_s % OW;
                        const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3;
                        dst_data[dst_idx_s] = ((float*)&res)[v];
                    }
                }
            }
        }
    }
}