#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Tiling parameters for kernel_mul_mm_f32_f32_l4_lm.
//
// LOAD_VEC_A / LOAD_VEC_B: floats fetched per global load from src0 / src1.
//   The kernel reads through float4 pointers and unpacks .s0-.s3 explicitly,
//   so both must be 4.
// BM / BN: rows of src0 / src1 covered by one work-group tile.
// BK:      length of the K slice staged in local memory per iteration.
// TM / TN: rows of src0 / src1 accumulated by each work-item
//          (a TM x TN tile of partial sums held per work-item).
#define LOAD_VEC_A 4
#define LOAD_VEC_B 4

#define BM 64
#define BN 64
#define BK 16
#define TM 4
#define TN 8

// Relations the kernel's index arithmetic relies on. Violating any of them
// would otherwise compile cleanly and produce wrong results.
#if LOAD_VEC_A != 4 || LOAD_VEC_B != 4
#error "LOAD_VEC_A and LOAD_VEC_B must be 4: the kernel loads float4 and unpacks .s0-.s3"
#endif
#if (BK % LOAD_VEC_A) != 0 || (BK % LOAD_VEC_B) != 0
#error "BK must be a multiple of LOAD_VEC_A and LOAD_VEC_B"
#endif
#if (BM % TM) != 0 || (BN % TN) != 0
#error "BM must be a multiple of TM and BN a multiple of TN"
#endif

// Tiled single-precision matrix multiply with batch broadcasting.
//
// For batch b = get_global_id(2) and every i < ne01, j < ne11 it writes
//
//   dst[b*batch_stride_d + j*stride_d + i] = sum_{k < ne00} A(i, k) * B(j, k)
//
// where A(i, k) is the float at element offset
//   batch_idx_a*batch_stride_a + i*stride_a + k   of src0,
// and   B(j, k) is the float at element offset
//   b*batch_stride_b + j*stride_b + k             of src1.
//
// Work decomposition:
//   - get_group_id(0) selects a block of BM src0 rows (output index i),
//     get_group_id(1) a block of BN src1 rows (output index j).
//   - Each work-item accumulates a TM x TN tile of partial sums in private
//     arrays. The th_r/th_c mapping assumes a 1-D local size of exactly
//     (BM/TM)*(BN/TN) work-items (128 with the default tiling). Fewer leave
//     part of the tile uncomputed; more index past buf_a/buf_b.
//   - K is traversed in slices of BK staged through local memory.
//     If ne00 is not a multiple of BK, the final slice is zero-padded.
//
// Batch broadcasting: b is split into (i13, i12) = (b / ne12, b % ne12), and
// the src0 batch is (i13 / r3) * ne02 + i12 / r2. r2/r3 presumably are the
// src1-to-src0 batch ratios in dims 2 and 3 (set by the host; confirm there).
//
// Units:
//   offset0 / offset1 / offsetd          - BYTE offsets applied before indexing.
//   stride_*, batch_stride_*             - ELEMENT (float) counts.
//
// Assumptions not checked here:
//   - ne00, stride_a, stride_b, batch_stride_a and batch_stride_b are
//     multiples of 4, and the offset src0/src1 pointers are float4-aligned.
//     These quantities are divided by 4 to index the float4 pointers.
//   - All offsets fit in 32-bit int, because every index is computed in int.
kernel void kernel_mul_mm_f32_f32_l4_lm(
        global float4 * src0,
        ulong offset0,
        global float4 * src1,
        ulong offset1,
        global float * dst,
        ulong offsetd,

        int ne00,
        int ne01,
        int ne02,
        int ne11,
        int ne12,

        int stride_a,
        int stride_b,
        int stride_d,

        int batch_stride_a,
        int batch_stride_b,
        int batch_stride_d,

        int r2,
        int r3
) {
    src0 = (global float4*)((global char*)src0 + offset0);
    src1 = (global float4*)((global char*)src1 + offset1);
    dst  = (global float*)((global char*)dst + offsetd);

    // Operand tiles stored K-major: element (k, row) is at k*BM + row
    // (k*BN + row for B), so the inner loop reads consecutive rows for a fixed k.
    local float buf_a[BM * BK];
    local float buf_b[BN * BK];

    const int batch_idx = get_global_id(2);

    // Map the dst/src1 batch index onto the (possibly broadcast) src0 batch.
    const int i13 = batch_idx / ne12;
    const int i12 = batch_idx % ne12;

    const int i03 = i13 / r3;
    const int i02 = i12 / r2;

    const int batch_idx_a = i03 * ne02 + i02;

    const int ir = get_group_id(0); // block index along src0 rows
    const int ic = get_group_id(1); // block index along src1 rows

    // Position of this work-item's TM x TN accumulator tile inside the block.
    const int tid  = get_local_id(0);
    const int th_r = tid % (BM / TM);
    const int th_c = tid / (BM / TM);

    // Cooperative-load mapping. Each pass loads one float4 per work-item:
    // loadr_* picks the float4 within the BK slice, loadc_* picks the row.
    const int loadr_a = tid % (BK / LOAD_VEC_A);
    const int loadc_a = tid / (BK / LOAD_VEC_A);
    const int loadr_b = tid % (BK / LOAD_VEC_B);
    const int loadc_b = tid / (BK / LOAD_VEC_B);

    // Rows covered by one pass of the whole work-group.
    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;

    // float4 index of K = 0 in the first row of this block; advanced per slice.
    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;

    // sums[cc*TM + cr] accumulates output (row dr + cr, column dc + cc).
    float sums[TM * TN];
    float cache_a[TM];
    float cache_b[TN];

    for (int i = 0; i < TM * TN; i++) {
        sums[i] = 0.0f;
    }

    for (int block = 0; block < ne00; block += BK) {
        // K index of the first float of the float4 this work-item loads.
        // Past the end of K (partial last slice), load zeros instead of
        // reading the next row or beyond the buffer.
        const int k_a = block + loadr_a * LOAD_VEC_A;
        const int k_b = block + loadr_b * LOAD_VEC_B;

        for (int l = 0; l < BM; l += loadstride_a) {
            const int row = loadc_a + l;
            float4 v = (float4)(0.0f);
            if (ir*BM + row < ne01 && k_a < ne00) {
                v = src0[pos_a + row * stride_a / LOAD_VEC_A + loadr_a];
            }
            buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + row] = v.s0;
            buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + row] = v.s1;
            buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + row] = v.s2;
            buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + row] = v.s3;
        }

        for (int l = 0; l < BN; l += loadstride_b) {
            const int row = loadc_b + l;
            float4 v = (float4)(0.0f);
            if (ic*BN + row < ne11 && k_b < ne00) {
                v = src1[pos_b + row * stride_b / LOAD_VEC_B + loadr_b];
            }
            buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + row] = v.s0;
            buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + row] = v.s1;
            buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + row] = v.s2;
            buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + row] = v.s3;
        }

        // Tiles must be fully written before any work-item reads them.
        barrier(CLK_LOCAL_MEM_FENCE);

        pos_a += BK / LOAD_VEC_A;
        pos_b += BK / LOAD_VEC_B;

        for (int i = 0; i < BK; i++) {
            for (int j = 0; j < TM; j++) {
                cache_a[j] = buf_a[i * BM + th_r * TM + j];
            }

            for (int j = 0; j < TN; j++) {
                cache_b[j] = buf_b[i * BN + th_c * TN + j];
            }

            for (int cc = 0; cc < TN; cc++) {
                for (int cr = 0; cr < TM; cr++) {
                    const int sums_idx = cc*TM + cr;
                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
                }
            }
        }

        // Everyone must finish reading the tiles before the next slice overwrites them.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    const int dr = ir * BM + th_r * TM;
    const int dc = ic * BN + th_c * TN;

    const int offsets = batch_idx * batch_stride_d;

    // Output (i, j) lives at j*stride_d + i; rows/columns outside the matrix are skipped.
    for (int cc = 0; cc < TN; cc++) {
        for (int cr = 0; cr < TM; cr++) {
            if (dr + cr < ne01 && dc + cc < ne11) {
                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
            }
        }
    }
}