// Half-precision types (half, half4) are used for the A operand.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// On Qualcomm devices that expose cl_qcom_reqd_sub_group_size, request the
// "full" sub-group size for the kernel; elsewhere the macro expands to nothing.
// NOTE(review): the macro name suggests "full" corresponds to 128 work-items;
// confirm against the Qualcomm extension documentation.
#if defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#else
#define REQD_SUBGROUP_SIZE_128
#endif

// Tile geometry.
#define OPWM 64   // output rows (M direction) produced per work-group
#define OPWN 64   // output columns (N direction) produced per work-group
#define CPWK 8    // K elements consumed per tile iteration
#define OPTM 4    // output rows accumulated per work-item
#define OPTN 8    // output columns accumulated per work-item

// Derived quantities.
#define WG_M (OPWM / OPTM)   // work-items along local dimension 0
#define WG_N (OPWN / OPTN)   // work-items along local dimension 1
#define VEC_K (CPWK / 4)     // 4-wide vectors per tile row along K

// The kernel's indexing silently depends on the relations below; fail the
// build instead of computing garbage if the constants are ever retuned.
#if (OPWM % OPTM) != 0 || (OPWN % OPTN) != 0
#error "OPWM/OPWN must be multiples of OPTM/OPTN"
#endif
#if (CPWK % 4) != 0
#error "CPWK must be a multiple of 4 (tiles are staged as 4-wide vectors)"
#endif
// Each work-item stages exactly one vector of the A tile and one of the B tile,
// so the work-group size must equal the number of vectors in each tile.
#if (WG_M * WG_N) != (OPWM * VEC_K) || (WG_M * WG_N) != (OPWN * VEC_K)
#error "WG_M * WG_N must equal OPWM * VEC_K and OPWN * VEC_K"
#endif
 19
 20REQD_SUBGROUP_SIZE_128
 21__kernel void mul_mat_f16_f32(
 22    const int M, const int N, const int K,
 23    __global const void* A_void, ulong A_offset,
 24    __global const void* B_void, ulong B_offset,
 25    __global       void* C_void, ulong C_offset) {
 26
 27    __global const half*  A = (__global const half* )((__global const char*)A_void + A_offset);
 28    __global const float* B = (__global const float*)((__global const char*)B_void + B_offset);
 29    __global       float* C = (__global       float*)((__global       char*)C_void + C_offset);
 30
 31    const int lidm = get_local_id(0);
 32    const int lidn = get_local_id(1);
 33    const int lid = lidn * WG_M + lidm;
 34
 35    const int offsetM = get_group_id(0) * OPWM;
 36    const int offsetN = get_group_id(1) * OPWN;
 37
 38    __local half4  Alocal[OPWM][VEC_K];
 39    __local float4 Blocal[OPWN][VEC_K];
 40
 41    float sum[OPTM][OPTN];
 42
 43    for (int wm = 0; wm < OPTM; wm++) {
 44        for (int wn = 0; wn < OPTN; wn++) {
 45            sum[wm][wn] = 0.0f;
 46        }
 47    }
 48
 49    const int numTiles = (K + CPWK - 1) / CPWK;
 50
 51    const int load_row_a = lid % OPWM;
 52    const int load_vec_k_a = lid / OPWM;
 53    const int global_row_a = offsetM + load_row_a;
 54
 55    const int load_row_b = lid % OPWN;
 56    const int load_vec_k_b = lid / OPWN;
 57    const int global_row_b = offsetN + load_row_b;
 58
 59    for (int t = 0; t < numTiles; t++) {
 60        const int k_start = t * CPWK;
 61        const int k_vec_start_a = k_start + load_vec_k_a * 4;
 62        const int k_vec_start_b = k_start + load_vec_k_b * 4;
 63
 64        if (global_row_a < M && k_vec_start_a < K) {
 65            if (k_vec_start_a + 3 < K) {
 66                Alocal[load_row_a][load_vec_k_a] = vload4(0, A + global_row_a * K + k_vec_start_a);
 67            } else {
 68                half4 tempA = (half4)(0.0h);
 69                if (k_vec_start_a < K) tempA.s0 = A[global_row_a * K + k_vec_start_a];
 70                if (k_vec_start_a + 1 < K) tempA.s1 = A[global_row_a * K + k_vec_start_a + 1];
 71                if (k_vec_start_a + 2 < K) tempA.s2 = A[global_row_a * K + k_vec_start_a + 2];
 72                Alocal[load_row_a][load_vec_k_a] = tempA;
 73            }
 74        } else {
 75            Alocal[load_row_a][load_vec_k_a] = (half4)(0.0h);
 76        }
 77
 78        if (global_row_b < N && k_vec_start_b < K) {
 79            if (k_vec_start_b + 3 < K) {
 80                Blocal[load_row_b][load_vec_k_b] = vload4(0, B + global_row_b * K + k_vec_start_b);
 81            } else {
 82                float4 tempB = (float4)(0.0f);
 83                if (k_vec_start_b < K) tempB.s0 = B[global_row_b * K + k_vec_start_b];
 84                if (k_vec_start_b + 1 < K) tempB.s1 = B[global_row_b * K + k_vec_start_b + 1];
 85                if (k_vec_start_b + 2 < K) tempB.s2 = B[global_row_b * K + k_vec_start_b + 2];
 86                Blocal[load_row_b][load_vec_k_b] = tempB;
 87            }
 88        } else {
 89            Blocal[load_row_b][load_vec_k_b] = (float4)(0.0f);
 90        }
 91
 92        barrier(CLK_LOCAL_MEM_FENCE);
 93
 94        #pragma unroll
 95        for (int k_vec = 0; k_vec < VEC_K; k_vec++) {
 96            float4 a_fvecs[OPTM];
 97            int current_row_a = lidm;
 98            for (int wm = 0; wm < OPTM; wm++) {
 99                a_fvecs[wm] = convert_float4(Alocal[current_row_a][k_vec]);
100                current_row_a += WG_M;
101            }
102
103            float4 b_fvecs[OPTN];
104            int current_row_b = lidn;
105            for (int wn = 0; wn < OPTN; wn++) {
106                b_fvecs[wn] = Blocal[current_row_b][k_vec];
107                current_row_b += WG_N;
108            }
109
110            for (int wm = 0; wm < OPTM; wm++) {
111                for (int wn = 0; wn < OPTN; wn++) {
112                    sum[wm][wn] += dot(a_fvecs[wm], b_fvecs[wn]);
113                }
114            }
115        }
116        barrier(CLK_LOCAL_MEM_FENCE);
117    }
118
119    for (int wm = 0; wm < OPTM; wm++) {
120        int globalRow = offsetM + lidm + wm * WG_M;
121        if (globalRow < M) {
122            for (int wn = 0; wn < OPTN; wn++) {
123                int globalCol = offsetN + lidn + wn * WG_N;
124                if (globalCol < N) {
125                    C[globalCol * M + globalRow] = sum[wm][wn];
126                }
127            }
128        }
129    }
130}