#pragma OPENCL EXTENSION cl_khr_fp16 : enable #define LOAD_VEC_A 4 #define LOAD_VEC_B 4 #define BM 64 #define BN 64 #define BK 32 #define TM 4 #define TN 8 kernel void kernel_mul_mm_q8_0_f32_l4_lm( global char4 * src0_q, global half * src0_d, global float4 * src1, ulong offset1, global float * dst, ulong offsetd, int ne00, int ne01, int ne02, int ne11, int ne12, int stride_a, int stride_b, int stride_d, int batch_stride_a, int batch_stride_b, int batch_stride_d, int r2, int r3 ) { src1 = (global float4*)((global char*)src1 + offset1); dst = (global float *)((global char*)dst + offsetd); local float buf_a[BM * BK]; local float buf_b[BN * BK]; const int batch_idx = get_global_id(2); const int i13 = batch_idx / ne12; const int i12 = batch_idx % ne12; const int i03 = i13 / r3; const int i02 = i12 / r2; const int batch_idx_a = i03 * ne02 + i02; const int ir = get_group_id(0); const int ic = get_group_id(1); const int tid = get_local_id(0); const int th_r = tid % (BM / TM); const int th_c = tid / (BM / TM); const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; float sums[TM * TN]; float cache_a[TM]; float cache_b[TN]; for (int i = 0; i < TM * TN; i++) { sums[i] = 0.0f; } for (int block = 0; block < ne00; block += BK) { for (int l = 0; l < BM; l += loadstride_a) { if (ir*BM + loadc_a + l < ne01) { int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; int ib = idx / 8; int iqs = idx % 8; float d = (float)src0_d[ib]; global char4 * qs = src0_q + ib*8 + iqs; char4 q = *qs; float4 v = convert_float4(q)*d; buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v.s0; buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v.s1; buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v.s2; buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v.s3; } else { buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f; buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f; } } for (int l = 0; l < BN; l += loadstride_b) { if (ic*BN + loadc_b + l < ne11) { int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; } else { buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; } } barrier(CLK_LOCAL_MEM_FENCE); pos_a += BK / LOAD_VEC_A; pos_b += BK / LOAD_VEC_B; for (int i = 0; i < BK; i++) { for (int j = 0; j < TM; j++) { cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; } for (int j = 0; j < TN; j++) { cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; } for (int cc = 0; cc < TN; cc++) { for (int cr = 0; cr < TM; cr++) { const int sums_idx = cc*TM + cr; sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); } } } barrier(CLK_LOCAL_MEM_FENCE); } const int dr = ir * BM + th_r * TM; const int dc = ic * BN + th_c * TN; const int offsets = batch_idx * batch_stride_d; for (int cc = 0; cc < TN; cc++) { for (int cr = 0; cr < TM; cr++) { if (dr + cr < ne01 && dc + cc < ne11) { dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; } } } }