#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Tile configuration for kernel_mul_mm_f32_f32_l4_lm.
//   BM x BN    : output tile produced by one work-group (BM rows of src0 x BN rows of src1)
//   BK         : K-chunk staged in local memory per iteration of the K loop
//   TM x TN    : output sub-tile accumulated by one work-item
//   LOAD_VEC_* : floats fetched per global load (the kernel reads src0/src1 through float4)
// The kernel's work-item -> sub-tile mapping covers the whole BM x BN tile only when the
// work-group has (BM/TM)*(BN/TN) work-items (128 with the values below).
#define LOAD_VEC_A 4
#define LOAD_VEC_B 4

#define BM 64
#define BN 64
#define BK 16
#define TM 4
#define TN 8

// Reject configurations the kernel cannot handle instead of silently producing garbage.
#if LOAD_VEC_A != 4 || LOAD_VEC_B != 4
#error "LOAD_VEC_A and LOAD_VEC_B must be 4: src0/src1 are read as float4 (.s0-.s3)"
#endif
#if (BK % LOAD_VEC_A) != 0 || (BK % LOAD_VEC_B) != 0
#error "BK must be a multiple of LOAD_VEC_A and LOAD_VEC_B"
#endif
#if (BM % TM) != 0 || (BN % TN) != 0
#error "BM must be a multiple of TM and BN a multiple of TN"
#endif
// With (BM/TM)*(BN/TN) work-items, one tile-load pass covers (BM/TM)*(BN/TN)*LOAD_VEC/BK
// rows; that row stride must divide BM and BN or the last pass writes outside the tile.
#if (BM % ((BM / TM) * (BN / TN) * LOAD_VEC_A / BK)) != 0 || \
    (BN % ((BM / TM) * (BN / TN) * LOAD_VEC_B / BK)) != 0
#error "BM/BN must be multiples of the per-pass load row stride"
#endif
 11
 12kernel void kernel_mul_mm_f32_f32_l4_lm(
 13    global float4 * src0,
 14    ulong offset0,
 15    global float4 * src1,
 16    ulong offset1,
 17    global float * dst,
 18    ulong offsetd,
 19
 20    int ne00,
 21    int ne01,
 22    int ne02,
 23    int ne11,
 24    int ne12,
 25
 26    int stride_a,
 27    int stride_b,
 28    int stride_d,
 29
 30    int batch_stride_a,
 31    int batch_stride_b,
 32    int batch_stride_d,
 33
 34    int r2,
 35    int r3
 36) {
 37    src0 = (global float4*)((global char*)src0 + offset0);
 38    src1 = (global float4*)((global char*)src1 + offset1);
 39    dst = (global float*)((global char*)dst + offsetd);
 40
 41    local float buf_a[BM * BK];
 42    local float buf_b[BN * BK];
 43
 44    const int batch_idx = get_global_id(2);
 45
 46    const int i13 = batch_idx / ne12;
 47    const int i12 = batch_idx % ne12;
 48
 49    const int i03 = i13 / r3;
 50    const int i02 = i12 / r2;
 51
 52    const int batch_idx_a = i03 * ne02 + i02;
 53
 54    const int ir = get_group_id(0);
 55    const int ic = get_group_id(1);
 56
 57    const int tid = get_local_id(0);
 58    const int th_r  = tid % (BM / TM);
 59    const int th_c  = tid / (BM / TM);
 60
 61    const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
 62    const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
 63    const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
 64    const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
 65
 66    const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
 67    const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
 68
 69    int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
 70    int pos_b = (batch_idx   * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
 71
 72    float sums[TM * TN];
 73    float cache_a[TM];
 74    float cache_b[TN];
 75
 76    for (int i = 0; i < TM * TN; i++) {
 77        sums[i] = 0.0f;
 78    }
 79
 80    for (int block = 0; block < ne00; block += BK) {
 81        for (int l = 0; l < BM; l += loadstride_a) {
 82            if (ir*BM + loadc_a + l < ne01) {
 83                const int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
 84                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = src0[idx].s0;
 85                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = src0[idx].s1;
 86                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = src0[idx].s2;
 87                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = src0[idx].s3;
 88            } else {
 89                buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
 90                buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
 91                buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f;
 92                buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f;
 93            }
 94        }
 95
 96        for (int l = 0; l < BN; l += loadstride_b) {
 97            if (ic*BN + loadc_b + l < ne11) {
 98                const int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
 99                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
100                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
101                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
102                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
103            } else {
104                buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
105                buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
106                buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
107                buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
108            }
109        }
110
111        barrier(CLK_LOCAL_MEM_FENCE);
112
113        pos_a += BK / LOAD_VEC_A;
114        pos_b += BK / LOAD_VEC_B;
115
116        for (int i = 0; i < BK; i++) {
117            for (int j = 0; j < TM; j++) {
118                cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
119            }
120
121            for (int j = 0; j < TN; j++) {
122                cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
123            }
124
125            for (int cc = 0; cc < TN; cc++) {
126                for (int cr = 0; cr < TM; cr++) {
127                    const int sums_idx = cc*TM + cr;
128                    sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
129                }
130            }
131        }
132        barrier(CLK_LOCAL_MEM_FENCE);
133    }
134
135    const int dr = ir * BM + th_r * TM;
136    const int dc = ic * BN + th_c * TN;
137
138    const int offsets = batch_idx * batch_stride_d;
139
140    for (int cc = 0; cc < TN; cc++) {
141        for (int cr = 0; cr < TM; cr++) {
142            if (dr + cr < ne01 && dc + cc < ne11) {
143                dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
144            }
145        }
146    }
147}