#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// On devices exposing cl_qcom_reqd_sub_group_size, request the "full"
// sub-group size for the kernel; elsewhere the attribute expands to nothing.
// NOTE(review): the macro name implies "full" == 128 lanes on the targeted
// Adreno GPUs — confirm against Qualcomm's extension documentation.
#if defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#else
#define REQD_SUBGROUP_SIZE_128
#endif

// Tiling parameters for mul_mat_f16_f32.
#define OPWM 64   // M-extent of the output tile computed by one work-group
#define OPWN 64   // N-extent of the output tile computed by one work-group
#define CPWK 8    // K elements staged in local memory per tile iteration
#define OPTM 4    // M-extent of outputs accumulated by one work-item
#define OPTN 8    // N-extent of outputs accumulated by one work-item

#define WG_M  (OPWM / OPTM)   // work-items along local dimension 0
#define WG_N  (OPWN / OPTN)   // work-items along local dimension 1
#define VEC_K (CPWK / 4)      // 4-wide vectors per staged row per tile

// Compile-time guards for the invariants the kernel's tiling relies on, so
// that retuning a constant fails loudly instead of silently leaving local
// memory partially loaded (or indexing past it).
#if (OPWM % OPTM) != 0 || (OPWN % OPTN) != 0
#error "OPWM and OPWN must be multiples of OPTM and OPTN respectively"
#endif
#if (CPWK % 4) != 0
#error "CPWK must be a multiple of 4: A and B are staged as 4-wide vectors"
#endif
#if (OPWM * VEC_K) != (WG_M * WG_N) || (OPWN * VEC_K) != (WG_M * WG_N)
#error "each work-item must load exactly one A vector and one B vector per tile"
#endif
/*
 * mul_mat_f16_f32: C = A * B^T with half-precision A, single-precision B and
 * float accumulation.
 *
 * Indexing used below:
 *   A: M rows of K halves, row-major           -> A[m * K + k]
 *   B: N rows of K floats, row-major           -> B[n * K + k]
 *   C: M x N floats stored column-major        -> C[n * M + m]
 * A_offset / B_offset / C_offset are byte offsets added to the buffer bases.
 *
 * Each work-group produces an OPWM x OPWN tile of C. Each work-item
 * accumulates OPTM x OPTN outputs, strided by WG_M along M and WG_N along N.
 * K is consumed in tiles of CPWK elements staged in local memory; out-of-range
 * rows and K tail elements are zero-filled so they contribute nothing.
 *
 * NOTE(review): the linear local id assumes the host launches with a local
 * size of (WG_M, WG_N, 1); this is not enforced here — confirm against the
 * host-side enqueue code.
 */
REQD_SUBGROUP_SIZE_128
__kernel void mul_mat_f16_f32(
        const int M, const int N, const int K,
        __global const void * A_void, ulong A_offset,
        __global const void * B_void, ulong B_offset,
        __global void * C_void, ulong C_offset) {

    __global const half  * A = (__global const half  *)((__global const char *)A_void + A_offset);
    __global const float * B = (__global const float *)((__global const char *)B_void + B_offset);
    __global       float * C = (__global       float *)((__global       char *)C_void + C_offset);

    const int lidm = get_local_id(0);
    const int lidn = get_local_id(1);
    const int lid  = lidn * WG_M + lidm;

    const int offsetM = get_group_id(0) * OPWM;
    const int offsetN = get_group_id(1) * OPWN;

    __local half4  Alocal[OPWM][VEC_K];
    __local float4 Blocal[OPWN][VEC_K];

    float sum[OPTM][OPTN];
    for (int wm = 0; wm < OPTM; wm++) {
        for (int wn = 0; wn < OPTN; wn++) {
            sum[wm][wn] = 0.0f;
        }
    }

    const int numTiles = (K + CPWK - 1) / CPWK;

    // Cooperative load mapping: every work-item fills exactly one half4 slot
    // of Alocal and one float4 slot of Blocal per tile.
    const int  load_row_a   = lid % OPWM;
    const int  load_vec_k_a = lid / OPWM;
    const int  global_row_a = offsetM + load_row_a;
    const bool row_a_valid  = global_row_a < M;

    const int  load_row_b   = lid % OPWN;
    const int  load_vec_k_b = lid / OPWN;
    const int  global_row_b = offsetN + load_row_b;
    const bool row_b_valid  = global_row_b < N;

    // Row bases in 64-bit: row * K can exceed INT_MAX for large operands,
    // which would be signed overflow in 32-bit int arithmetic. Hoisted out of
    // the tile loop because they do not depend on the tile index.
    const ulong row_base_a = (ulong)global_row_a * (ulong)K;
    const ulong row_base_b = (ulong)global_row_b * (ulong)K;

    for (int t = 0; t < numTiles; t++) {
        const int k_start = t * CPWK;
        const int k_a = k_start + load_vec_k_a * 4;
        const int k_b = k_start + load_vec_k_b * 4;

        // Stage A: full vector load when all 4 lanes are in range, otherwise
        // fetch only the in-range tail lanes; everything else stays zero.
        half4 va = (half4)(0.0h);
        if (row_a_valid && k_a < K) {
            __global const half * src_a = A + row_base_a + k_a;
            if (k_a + 3 < K) {
                va = vload4(0, src_a);
            } else {
                va.s0 = src_a[0];
                if (k_a + 1 < K) va.s1 = src_a[1];
                if (k_a + 2 < K) va.s2 = src_a[2];
            }
        }
        Alocal[load_row_a][load_vec_k_a] = va;

        // Stage B with the same scheme.
        float4 vb = (float4)(0.0f);
        if (row_b_valid && k_b < K) {
            __global const float * src_b = B + row_base_b + k_b;
            if (k_b + 3 < K) {
                vb = vload4(0, src_b);
            } else {
                vb.s0 = src_b[0];
                if (k_b + 1 < K) vb.s1 = src_b[1];
                if (k_b + 2 < K) vb.s2 = src_b[2];
            }
        }
        Blocal[load_row_b][load_vec_k_b] = vb;

        // Make the staged tile visible to the whole work-group.
        barrier(CLK_LOCAL_MEM_FENCE);

        #pragma unroll
        for (int k_vec = 0; k_vec < VEC_K; k_vec++) {
            float4 a_fvecs[OPTM];
            for (int wm = 0; wm < OPTM; wm++) {
                a_fvecs[wm] = convert_float4(Alocal[lidm + wm * WG_M][k_vec]);
            }

            float4 b_fvecs[OPTN];
            for (int wn = 0; wn < OPTN; wn++) {
                b_fvecs[wn] = Blocal[lidn + wn * WG_N][k_vec];
            }

            for (int wm = 0; wm < OPTM; wm++) {
                for (int wn = 0; wn < OPTN; wn++) {
                    sum[wm][wn] += dot(a_fvecs[wm], b_fvecs[wn]);
                }
            }
        }

        // Do not let the next tile overwrite local memory still being read.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Write back the in-range part of this work-item's OPTM x OPTN block.
    for (int wm = 0; wm < OPTM; wm++) {
        const int globalRow = offsetM + lidm + wm * WG_M;
        if (globalRow >= M) {
            continue;
        }
        for (int wn = 0; wn < OPTN; wn++) {
            const int globalCol = offsetN + lidn + wn * WG_N;
            if (globalCol < N) {
                // 64-bit index: col * M can exceed INT_MAX for large outputs.
                C[(ulong)globalCol * (ulong)M + (ulong)globalRow] = sum[wm][wn];
            }
        }
    }
}