1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
  2#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
  3
  4#ifdef cl_qcom_reqd_sub_group_size
  5#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
  6#define ADRENO_GPU 1
  7#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
  8#endif
  9
 10#ifdef ADRENO_GPU
 11REQD_SUBGROUP_SIZE_128
 12#endif
 13
 14kernel void kernel_mul_mm_q8_0_f32_8x4(
 15        global const uint * src0_q,
 16        global const half  * src0_d,
 17        __read_only image1d_buffer_t src1,
 18        global float * dst,
 19        int k,
 20        int m,
 21        int n,
 22        int n_no_padding,
 23        ulong offsetd
 24) {
 25
 26    int m_4 = m >> 2;
 27    int n_4 = n >> 2;
 28
 29    int gy   = get_global_id(0);
 30    int gx   = get_global_id(1);
 31    int gx_2 = gx << 2;
 32    dst  = (global float *)((global char*)dst  + offsetd);
 33
 34
 35    half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;
 36    half8 B;
 37    half4 deq;
 38
 39    __global const uint* wptr = src0_q + gx_2;
 40    __global const half* sptr = src0_d + gx_2;
 41
 42      for (int i = 0; i < k; i += 4) {
 43        uint4 pack4 = vload4(0, wptr + (i / 4) * m);
 44        half4 scale = vload4(0, sptr + (i / 32) * m);
 45
 46        char4 p0 = as_char4(pack4.s0);
 47        char4 p1 = as_char4(pack4.s1);
 48        char4 p2 = as_char4(pack4.s2);
 49        char4 p3 = as_char4(pack4.s3);
 50
 51        // ------------------- j = 0 (k = i+0) -------------------
 52        B.s0123 = read_imageh(src1, gy * 2 + (i + 0) * n_4);
 53        B.s4567 = read_imageh(src1, gy * 2 + (i + 0) * n_4 + 1);
 54
 55        half4 wj0 = convert_half4((char4)(p0.s0, p1.s0, p2.s0, p3.s0)) * scale;
 56
 57        c0 += B * wj0.s0;
 58        c1 += B * wj0.s1;
 59        c2 += B * wj0.s2;
 60        c3 += B * wj0.s3;
 61
 62        // ------------------- j = 1 (k = i+1) -------------------
 63        B.s0123 = read_imageh(src1, gy * 2 + (i + 1) * n_4);
 64        B.s4567 = read_imageh(src1, gy * 2 + (i + 1) * n_4 + 1);
 65
 66        half4 wj1 = convert_half4((char4)(p0.s1, p1.s1, p2.s1, p3.s1)) * scale;
 67
 68        c0 += B * wj1.s0;
 69        c1 += B * wj1.s1;
 70        c2 += B * wj1.s2;
 71        c3 += B * wj1.s3;
 72
 73        // ------------------- j = 2 (k = i+2) -------------------
 74        B.s0123 = read_imageh(src1, gy * 2 + (i + 2) * n_4);
 75        B.s4567 = read_imageh(src1, gy * 2 + (i + 2) * n_4 + 1);
 76
 77        half4 wj2 = convert_half4((char4)(p0.s2, p1.s2, p2.s2, p3.s2)) * scale;
 78
 79        c0 += B * wj2.s0;
 80        c1 += B * wj2.s1;
 81        c2 += B * wj2.s2;
 82        c3 += B * wj2.s3;
 83
 84        // ------------------- j = 3 (k = i+3) -------------------
 85        B.s0123 = read_imageh(src1, gy * 2 + (i + 3) * n_4);
 86        B.s4567 = read_imageh(src1, gy * 2 + (i + 3) * n_4 + 1);
 87
 88        half4 wj3 = convert_half4((char4)(p0.s3, p1.s3, p2.s3, p3.s3)) * scale;
 89
 90        c0 += B * wj3.s0;
 91        c1 += B * wj3.s1;
 92        c2 += B * wj3.s2;
 93        c3 += B * wj3.s3;
 94    }
 95
 96    int idx = (gy << 3) * m + (gx << 2);
 97
 98    if(idx+3 < m*n_no_padding){
 99        vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
100        idx += m;
101    }
102    if(idx+3 < m*n_no_padding){
103        vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
104        idx += m;
105    }
106    if(idx+3 < m*n_no_padding){
107        vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
108        idx += m;
109    }
110    if(idx+3 < m*n_no_padding){
111        vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
112        idx += m;
113    }
114    if(idx+3 < m*n_no_padding){
115        vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
116        idx += m;
117    }
118    if(idx+3 < m*n_no_padding){
119        vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
120        idx += m;
121    }
122    if(idx+3 < m*n_no_padding){
123        vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
124        idx += m;
125    }
126    if(idx+3 < m*n_no_padding){
127        vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
128    }
129}