1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
  2
  3//------------------------------------------------------------------------------
  4// mul
  5//------------------------------------------------------------------------------
  6kernel void kernel_mul(
  7        global char * src0,
  8        ulong offset0,
  9        global char * src1,
 10        ulong offset1,
 11        global char * dst,
 12        ulong offsetd,
 13        int ne00,
 14        int ne01,
 15        int ne02,
 16        int ne03,
 17        ulong nb00,
 18        ulong nb01,
 19        ulong nb02,
 20        ulong nb03,
 21        int ne10,
 22        int ne11,
 23        int ne12,
 24        int ne13,
 25        ulong nb10,
 26        ulong nb11,
 27        ulong nb12,
 28        ulong nb13,
 29        int ne0,
 30        int ne1,
 31        int ne2,
 32        int ne3,
 33        ulong nb0,
 34        ulong nb1,
 35        ulong nb2,
 36        ulong nb3
 37) {
 38    src0 = src0 + offset0;
 39    src1 = src1 + offset1;
 40    dst  = dst + offsetd;
 41
 42    int i03 = get_group_id(2);
 43    int i02 = get_group_id(1);
 44    int i01 = get_group_id(0);
 45
 46    int i13 = i03 % ne13;
 47    int i12 = i02 % ne12;
 48    int i11 = i01 % ne11;
 49
 50    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
 51    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
 52    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
 53
 54    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
 55        const int i10 = i0 % ne10;
 56        *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) * *((global float *)(src1_ptr + i10*nb10));
 57    }
 58}
 59
 60// assumption: src1 is a row
 61// broadcast src1 into src0
 62kernel void kernel_mul_row(
 63        global float4 * src0,
 64        ulong offset0,
 65        global float4 * src1,
 66        ulong offset1,
 67        global float4 * dst,
 68        ulong offsetd,
 69        int ne
 70) {
 71    src0 = (global float4*)((global char*)src0 + offset0);
 72    src1 = (global float4*)((global char*)src1 + offset1);
 73    dst = (global float4*)((global char*)dst + offsetd);
 74
 75    // This performs better than using %.
 76    uint gid = get_global_id(0);
 77    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
 78    dst[gid] = src0[gid] * src1[idx1];
 79}
 80
 81kernel void kernel_mul_f16(
 82        global char * src0,
 83        ulong offset0,
 84        global char * src1,
 85        ulong offset1,
 86        global char * dst,
 87        ulong offsetd,
 88        int ne00,
 89        int ne01,
 90        int ne02,
 91        int ne03,
 92        ulong nb00,
 93        ulong nb01,
 94        ulong nb02,
 95        ulong nb03,
 96        int ne10,
 97        int ne11,
 98        int ne12,
 99        int ne13,
100        ulong nb10,
101        ulong nb11,
102        ulong nb12,
103        ulong nb13,
104        int ne0,
105        int ne1,
106        int ne2,
107        int ne3,
108        ulong nb0,
109        ulong nb1,
110        ulong nb2,
111        ulong nb3
112) {
113    src0 = src0 + offset0;
114    src1 = src1 + offset1;
115    dst  = dst + offsetd;
116
117    int i03 = get_group_id(2);
118    int i02 = get_group_id(1);
119    int i01 = get_group_id(0);
120
121    int i13 = i03 % ne13;
122    int i12 = i02 % ne12;
123    int i11 = i01 % ne11;
124
125    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
126    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
127    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
128
129    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
130        const int i10 = i0 % ne10;
131        *((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) * *((global half *)(src1_ptr + i10*nb10));
132    }
133}
134
// Row-broadcast f16 multiply on half4 elements.
//
// offset0/offset1/offsetd are byte offsets applied to the base pointers.
// ne is the repeat period of src1, counted in half4 elements: work-item
// `gid` reads src1[gid % ne].
// There is no bounds guard; every work-item writes dst[get_global_id(0)].
kernel void kernel_mul_row_f16(
        global half4 * src0,
        ulong offset0,
        global half4 * src1,
        ulong offset1,
        global half4 * dst,
        ulong offsetd,
        int ne
) {
    // Apply the byte offsets, then view the buffers as half4 arrays again.
    global half4 * lhs = (global half4*)((global char*)src0 + offset0);
    global half4 * rhs = (global half4*)((global char*)src1 + offset1);
    global half4 * out = (global half4*)((global char*)dst + offsetd);

    const uint item = get_global_id(0);

    // item % ne, spelled as divide/multiply/subtract because this performs
    // better than using %.
    const uint wraps    = item/ne;
    const uint row_item = item - wraps*ne;

    const half4 a = lhs[item];
    const half4 b = rhs[row_item];
    out[item] = a * b;
}