// src0_q, src0_d, and src1 are transposed as a preprocessing step
// 4-bit weights are transposed in groups of 4 (one unsigned short int per group)
// consider weights originally "next to each other", now "on top of each other"
// each fiber computes an 8x4 tile of output elements
// using unshuffled weights
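//
// illustrative sketch of the packing described above (low nibble first): one
// ushort in src0_q packs the 4-bit weights of 4 consecutive K-dim elements
// of a single output column, e.g.
//
//   ushort w = 0xDCBA;
//   (w >>  0) & 0xF -> 0xA  // element i   (j=0 below)
//   (w >>  4) & 0xF -> 0xB  // element i+1 (j=1)
//   (w >>  8) & 0xF -> 0xC  // element i+2 (j=2)
//   (w >> 12) & 0xF -> 0xD  // element i+3 (j=3)
//
// each nibble is dequantized as (nibble - 8) * scale, one scale per 32 weights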

#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif

#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_128
#endif

kernel void kernel_mul_mat_Ab_Bi_8x4(
        global const ushort * src0_q,       // quantized A
        global const half   * src0_d,       // A scales
        __read_only image1d_buffer_t src1,  // B (1d image)
        global float * dst,                 // C
        int m,                              // M
        int n,                              // N with padding
        int k,                              // K
        int n_no_padding                    // N without padding
) {

    int m_4 = m >> 2;
    int n_4 = n >> 2;

    int gy = get_global_id(0);
    int gx = get_global_id(1);
    int gx_2 = gx << 2;
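    // gy selects the 8-wide output tile along N, gx the 4-wide tile along M;
    // gx_2 = gx * 4 is the first of this fiber's 4 output columns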

    half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; // 8x4 output elements
    half8 B; // registers for activations
    half4 dequantized_weights; // registers for dequantized weights
    __global const ushort * weight_ptr = src0_q + gx_2; // pointer for weights
    __global const half * scale_ptr = src0_d + gx_2;    // pointer for scales

    for (int i = 0; i < k; i += 4) { // loop through the K dimension

        B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4));
        B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1);
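        // each read_imageh on the 1d image buffer returns a half4, so the two
        // reads above fetch the 8 activations of row i of the transposed src1
        // that line up with this fiber's 8 output positions along N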

        // keep (i/4) and (i/32) in parentheses; integer division rounds down
        // load 4 consecutive groups of 4 weights
        ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m)); // (i/4) because weights are grouped in 4s

        // load 4 consecutive scales
        half4 scale = vload4(0, scale_ptr + (i/32)*(m)); // (i/32) because there is 1 scale per 32 elements
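        // one scale covers a block of 32 weights; together with the
        // (q - 8) * scale dequantization below, this presumably corresponds
        // to a Q4_0-style 4-bit block layout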

        // j=0
        dequantized_weights.s0 = ((bits4.s0 & (0x000F)) - 8) * scale.s0; // dequantize one row (4 of the 16 loaded weights)
        dequantized_weights.s1 = ((bits4.s1 & (0x000F)) - 8) * scale.s1;
        dequantized_weights.s2 = ((bits4.s2 & (0x000F)) - 8) * scale.s2;
        dequantized_weights.s3 = ((bits4.s3 & (0x000F)) - 8) * scale.s3;
        c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
        c1 += B * dequantized_weights.s1;
        c2 += B * dequantized_weights.s2;
        c3 += B * dequantized_weights.s3;

        // j=1
        B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4));
        B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1);
        dequantized_weights.s0 = (((bits4.s0 & (0x00F0)) >> 4) - 8) * scale.s0; // dequantize one row (4 of the 16 loaded weights)
        dequantized_weights.s1 = (((bits4.s1 & (0x00F0)) >> 4) - 8) * scale.s1;
        dequantized_weights.s2 = (((bits4.s2 & (0x00F0)) >> 4) - 8) * scale.s2;
        dequantized_weights.s3 = (((bits4.s3 & (0x00F0)) >> 4) - 8) * scale.s3;
        c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
        c1 += B * dequantized_weights.s1;
        c2 += B * dequantized_weights.s2;
        c3 += B * dequantized_weights.s3;

        // j=2
        B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4));
        B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1);
        dequantized_weights.s0 = (((bits4.s0 & (0x0F00)) >> 8) - 8) * scale.s0; // dequantize one row (4 of the 16 loaded weights)
        dequantized_weights.s1 = (((bits4.s1 & (0x0F00)) >> 8) - 8) * scale.s1;
        dequantized_weights.s2 = (((bits4.s2 & (0x0F00)) >> 8) - 8) * scale.s2;
        dequantized_weights.s3 = (((bits4.s3 & (0x0F00)) >> 8) - 8) * scale.s3;
        c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
        c1 += B * dequantized_weights.s1;
        c2 += B * dequantized_weights.s2;
        c3 += B * dequantized_weights.s3;

        // j=3
        B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4));
        B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1);
        dequantized_weights.s0 = (((bits4.s0 & (0xF000)) >> 12) - 8) * scale.s0; // dequantize one row (4 of the 16 loaded weights)
        dequantized_weights.s1 = (((bits4.s1 & (0xF000)) >> 12) - 8) * scale.s1;
        dequantized_weights.s2 = (((bits4.s2 & (0xF000)) >> 12) - 8) * scale.s2;
        dequantized_weights.s3 = (((bits4.s3 & (0xF000)) >> 12) - 8) * scale.s3;
        c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
        c1 += B * dequantized_weights.s1;
        c2 += B * dequantized_weights.s2;
        c3 += B * dequantized_weights.s3;
    }

    int idx = (gy<<3)*m + (gx<<2); // starting index for the vectorized stores of the 8x4 tile
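    // dst is laid out with a row stride of m: the tile starts at row gy*8
    // (along N) and column gx*4 (along M)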

    // check that each store lands in a valid location; required when N is not a multiple of 8
    // sequential if statements (rather than a loop) let registers be reused for each store,
    // which reduces the register footprint and increases the number of concurrent waves
    if(idx+3 < m*n_no_padding){
        vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
        idx += m;
    }
    if(idx+3 < m*n_no_padding){
        vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
        idx += m;
    }
    if(idx+3 < m*n_no_padding){
        vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
        idx += m;
    }
    if(idx+3 < m*n_no_padding){
        vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
        idx += m;
    }
    if(idx+3 < m*n_no_padding){
        vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
        idx += m;
    }
    if(idx+3 < m*n_no_padding){
        vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
        idx += m;
    }
    if(idx+3 < m*n_no_padding){
        vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
        idx += m;
    }
    if(idx+3 < m*n_no_padding){
        vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
    }
}