// src0_q, src0_d, src1 are transposed as a preprocessing step
// 4-bit weights are transposed in groups of 4 (packed into unsigned short ints)
// consider weights originally "next to each other", now "on top of each other"
// each fiber computes an 8x4 tile of output elements
// using unshuffled weights

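// Packing sketch (an illustration derived from the dequantization loop below,
// not upstream documentation): the 4 consecutive ushorts loaded per step belong
// to the fiber's 4 output columns; within one ushort the 4 nibbles run along K:
//   bits  3..0  -> weight at k = i      (j=0)
//   bits  7..4  -> weight at k = i+1    (j=1)
//   bits 11..8  -> weight at k = i+2    (j=2)
//   bits 15..12 -> weight at k = i+3    (j=3)
// e.g. a ushort 0x7C2A yields nibbles 0xA, 0x2, 0xC, 0x7, i.e. (10-8)=2,
// (2-8)=-6, (12-8)=4, (7-8)=-1 before scaling
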
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif

#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_128
#endif

kernel void kernel_mul_mat_Ab_Bi_8x4(
    global const ushort * src0_q,      // quantized A
    global const half * src0_d,        // A scales
    __read_only image1d_buffer_t src1, // B (1d image)
    global float * dst,                // C
    int m,                             // M
    int n,                             // N with padding
    int k,                             // K
    int n_no_padding                   // N without padding
) {

    int m_4 = m >> 2;
    int n_4 = n >> 2;

    int gy = get_global_id(0);
    int gx = get_global_id(1);
    int gx_2 = gx << 2;

    half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0;              // 8x4 output elements
    half8 B;                                           // registers for activations
    half4 dequantized_weights;                         // registers for dequantized weights
    __global const ushort* weight_ptr = src0_q + gx_2; // pointer for weights
    __global const half* scale_ptr = src0_d + gx_2;    // pointer for scales

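    // each fiber owns M columns gx*4 .. gx*4+3 and N rows gy*8 .. gy*8+7 of the
    // output (derived from the stores at the bottom of the kernel)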
    for (int i = 0; i < k; i += 4) { // loop through the K dimension

        B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4));
        B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1);
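        // the two reads above fetch half4 texels 2*gy and 2*gy+1, i.e. the 8
        // activations at N positions 8*gy .. 8*gy+7 for K index i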

        // keep (i/4) and (i/32) in parentheses; integer division rounds down
        // load 4 consecutive groups of 4 weights
        ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m)); // (i/4) because weights are grouped in 4s

        // load 4 consecutive scales
        half4 scale = vload4(0, scale_ptr + (i/32)*(m)); // (i/32) because there is 1 scale per 32 elements
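        // worked example (derived from the divisions above): at i = 36,
        // (i/4) = 9 selects the 10th packed ushort row and (i/32) = 1 selects
        // the 2nd scale row, so one scale row serves 8 packed rows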

        // j=0
        dequantized_weights.s0 = ((bits4.s0 & (0x000F)) - 8) * scale.s0; // dequantize a row of the 16 weights
        dequantized_weights.s1 = ((bits4.s1 & (0x000F)) - 8) * scale.s1;
        dequantized_weights.s2 = ((bits4.s2 & (0x000F)) - 8) * scale.s2;
        dequantized_weights.s3 = ((bits4.s3 & (0x000F)) - 8) * scale.s3;
        c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
        c1 += B * dequantized_weights.s1;
        c2 += B * dequantized_weights.s2;
        c3 += B * dequantized_weights.s3;
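        // note: the 4-bit values are stored biased by 8, so "- 8" recovers the
        // signed range [-8, 7] before the per-group scale is applied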

        // j=1
        B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4));
        B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1);
        dequantized_weights.s0 = (((bits4.s0 & (0x00F0)) >> 4) - 8) * scale.s0; // dequantize a row of the 16 weights
        dequantized_weights.s1 = (((bits4.s1 & (0x00F0)) >> 4) - 8) * scale.s1;
        dequantized_weights.s2 = (((bits4.s2 & (0x00F0)) >> 4) - 8) * scale.s2;
        dequantized_weights.s3 = (((bits4.s3 & (0x00F0)) >> 4) - 8) * scale.s3;
        c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
        c1 += B * dequantized_weights.s1;
        c2 += B * dequantized_weights.s2;
        c3 += B * dequantized_weights.s3;

        // j=2
        B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4));
        B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1);
        dequantized_weights.s0 = (((bits4.s0 & (0x0F00)) >> 8) - 8) * scale.s0; // dequantize a row of the 16 weights
        dequantized_weights.s1 = (((bits4.s1 & (0x0F00)) >> 8) - 8) * scale.s1;
        dequantized_weights.s2 = (((bits4.s2 & (0x0F00)) >> 8) - 8) * scale.s2;
        dequantized_weights.s3 = (((bits4.s3 & (0x0F00)) >> 8) - 8) * scale.s3;
        c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
        c1 += B * dequantized_weights.s1;
        c2 += B * dequantized_weights.s2;
        c3 += B * dequantized_weights.s3;

        // j=3
        B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4));
        B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1);
        dequantized_weights.s0 = (((bits4.s0 & (0xF000)) >> 12) - 8) * scale.s0; // dequantize a row of the 16 weights
        dequantized_weights.s1 = (((bits4.s1 & (0xF000)) >> 12) - 8) * scale.s1;
        dequantized_weights.s2 = (((bits4.s2 & (0xF000)) >> 12) - 8) * scale.s2;
        dequantized_weights.s3 = (((bits4.s3 & (0xF000)) >> 12) - 8) * scale.s3;
        c0 += B * dequantized_weights.s0; // vector-scalar multiplication to accumulate
        c1 += B * dequantized_weights.s1;
        c2 += B * dequantized_weights.s2;
        c3 += B * dequantized_weights.s3;
    }

    int idx = (gy<<3)*m + (gx<<2); // start of this fiber's 8x4 tile; stored below as eight vectorized float4 writes

    // conditional check that each store hits a valid location; required when N is not a multiple of 8
    // the if statements allow registers to be reused across the stores
    // this reduces the register footprint, which increases the number of concurrent waves and boosts performance
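    // e.g. (illustrative numbers): with m = 4096, n = 16, n_no_padding = 13, the
    // fiber at gy = 1 covers output rows 8..15; the guards let rows 8..12 store
    // and skip rows 13..15, which lie in the padding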
    if (idx+3 < m*n_no_padding) {
        vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx);
        idx += m;
    }
    if (idx+3 < m*n_no_padding) {
        vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx);
        idx += m;
    }
    if (idx+3 < m*n_no_padding) {
        vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx);
        idx += m;
    }
    if (idx+3 < m*n_no_padding) {
        vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx);
        idx += m;
    }
    if (idx+3 < m*n_no_padding) {
        vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx);
        idx += m;
    }
    if (idx+3 < m*n_no_padding) {
        vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx);
        idx += m;
    }
    if (idx+3 < m*n_no_padding) {
        vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx);
        idx += m;
    }
    if (idx+3 < m*n_no_padding) {
        vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx);
    }
}
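
// Dispatch sketch (inferred from the indexing above, not part of this kernel):
// each fiber produces an 8 (N) x 4 (M) tile, so a host would launch roughly
// global_work_size = { n/8, m/4 }, with n already padded as the arguments note.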