1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
  2
  3//------------------------------------------------------------------------------
  4// div
  5//------------------------------------------------------------------------------
  6kernel void kernel_div(
  7        global char * src0,
  8        ulong offset0,
  9        global char * src1,
 10        ulong offset1,
 11        global char * dst,
 12        ulong offsetd,
 13        ulong nb00,
 14        ulong nb01,
 15        ulong nb02,
 16        ulong nb03,
 17        int ne10,
 18        int ne11,
 19        int ne12,
 20        int ne13,
 21        ulong nb10,
 22        ulong nb11,
 23        ulong nb12,
 24        ulong nb13,
 25        int ne0,
 26        ulong nb0,
 27        ulong nb1,
 28        ulong nb2,
 29        ulong nb3
 30) {
 31    src0 = src0 + offset0;
 32    src1 = src1 + offset1;
 33    dst  = dst + offsetd;
 34
 35    int i03 = get_group_id(2);
 36    int i02 = get_group_id(1);
 37    int i01 = get_group_id(0);
 38
 39    int i13 = i03 % ne13;
 40    int i12 = i02 % ne12;
 41    int i11 = i01 % ne11;
 42
 43    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
 44    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
 45    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
 46
 47    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
 48        const int i10 = i0 % ne10;
 49        *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i0*nb00)) / *((global float *)(src1_ptr + i10*nb10));
 50    }
 51}
 52
 53// assumption: src1 is a row
 54// broadcast src1 into src0
 55kernel void kernel_div_row(
 56        global float4 * src0,
 57        ulong offset0,
 58        global float4 * src1,
 59        ulong offset1,
 60        global float4 * dst,
 61        ulong offsetd,
 62        int ne
 63) {
 64    src0 = (global float4*)((global char*)src0 + offset0);
 65    src1 = (global float4*)((global char*)src1 + offset1);
 66    dst = (global float4*)((global char*)dst + offsetd);
 67
 68    // This performs better than using %.
 69    uint gid = get_global_id(0);
 70    uint idx1 = gid - (gid/ne)*ne; // get_global_id(0) % ne
 71    dst[gid] = src0[gid] / src1[idx1];
 72}
 73
 74kernel void kernel_div_f16(
 75        global char * src0,
 76        ulong offset0,
 77        global char * src1,
 78        ulong offset1,
 79        global char * dst,
 80        ulong offsetd,
 81        ulong nb00,
 82        ulong nb01,
 83        ulong nb02,
 84        ulong nb03,
 85        int ne10,
 86        int ne11,
 87        int ne12,
 88        int ne13,
 89        ulong nb10,
 90        ulong nb11,
 91        ulong nb12,
 92        ulong nb13,
 93        int ne0,
 94        ulong nb0,
 95        ulong nb1,
 96        ulong nb2,
 97        ulong nb3
 98) {
 99    src0 = src0 + offset0;
100    src1 = src1 + offset1;
101    dst  = dst + offsetd;
102
103    int i03 = get_group_id(2);
104    int i02 = get_group_id(1);
105    int i01 = get_group_id(0);
106
107    int i13 = i03 % ne13;
108    int i12 = i02 % ne12;
109    int i11 = i01 % ne11;
110
111    global char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
112    global char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
113    global char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
114
115    for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
116        const int i10 = i0 % ne10;
117        *((global half *)(dst_ptr + i0*nb0)) = *((global half *)(src0_ptr + i0*nb00)) / *((global half *)(src1_ptr + i10*nb10));
118    }
119}
120
// Row-broadcast f16 (half) division: dst = src0 / src1, where src1 holds ne
// half4 values that are reused for every group of ne consecutive half4 values
// of src0.
//
// Each work-item computes one half4 of dst, at index get_global_id(0). The
// matching src1 element is that index wrapped by ne.
//
// offset0/offset1/offsetd: byte offsets of each tensor inside its buffer.
// NOTE(review): there is no bounds check. This assumes the host launches
// exactly one work-item per half4 of dst and that the tensors are laid out
// contiguously — confirm against the caller.
kernel void kernel_div_row_f16(
        global half4 * src0,
        ulong offset0,
        global half4 * src1,
        ulong offset1,
        global half4 * dst,
        ulong offsetd,
        int ne
) {
    global half4 * a = (global half4 *)((global char *)src0 + offset0);
    global half4 * b = (global half4 *)((global char *)src1 + offset1);
    global half4 * d = (global half4 *)((global char *)dst  + offsetd);

    const uint i = get_global_id(0);
    // Same value as i % ne. It is written without % because that form
    // performs better here.
    const uint j = i - (i / ne) * ne;
    d[i] = a[i] / b[j];
}