1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
  2
  3// v = { mp, L, d }
  4inline uint fastdiv(uint n, uint4 v) {
  5    uint msbs;
  6    msbs = mul_hi(n, v.s0);
  7    return (msbs + n) >> v.s1;
  8}
  9inline uint fastmod(uint n, uint4 v) {
 10    uint q = fastdiv(n, v);
 11    return n - q * v.s2;
 12}
 13
 14kernel void kernel_set_rows_f32_i64(
 15        global char * src0,
 16        ulong         offset0,
 17        global char * src1,
 18        ulong         offset1,
 19        global char * dst,
 20        ulong         offsetd,
 21        int           ne01,
 22        ulong         nb01,
 23        ulong         nb02,
 24        ulong         nb03,
 25        uint4         ne11,
 26        uint4         ne12,
 27        ulong         nb10,
 28        ulong         nb11,
 29        ulong         nb12,
 30        int           nblk0,
 31        ulong         nb1,
 32        ulong         nb2,
 33        ulong         nb3
 34) {
 35    src0 = src0 + offset0;
 36    src1 = src1 + offset1;
 37    dst  = dst  + offsetd;
 38
 39    int i03 = get_group_id(2);
 40    int i02 = get_group_id(1);
 41    int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
 42
 43    if (i01 >= ne01) {
 44        return;
 45    }
 46
 47    //int i12 = i03%ne12;
 48    //int i11 = i02%ne11;
 49    int i12 = fastmod(i03, ne12);
 50    int i11 = fastmod(i02, ne11);
 51
 52    int i10 = i01;
 53    long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
 54
 55    global float * dst_row = (global float *) (dst  +  i1*nb1  + i02*nb2  + i03*nb3);
 56    global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
 57
 58    for (int ind = get_local_id(0); ind < nblk0; ind += get_local_size(0)) {
 59        dst_row[ind] = (float)src_row[ind];
 60    }
 61}
 62
 63kernel void kernel_set_rows_f16_i64(
 64        global char * src0,
 65        ulong         offset0,
 66        global char * src1,
 67        ulong         offset1,
 68        global char * dst,
 69        ulong         offsetd,
 70        int           ne01,
 71        ulong         nb01,
 72        ulong         nb02,
 73        ulong         nb03,
 74        uint4         ne11,
 75        uint4         ne12,
 76        ulong         nb10,
 77        ulong         nb11,
 78        ulong         nb12,
 79        int           nblk0,
 80        ulong         nb1,
 81        ulong         nb2,
 82        ulong         nb3
 83) {
 84    src0 = src0 + offset0;
 85    src1 = src1 + offset1;
 86    dst  = dst  + offsetd;
 87
 88    int i03 = get_group_id(2);
 89    int i02 = get_group_id(1);
 90    int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
 91
 92    if (i01 >= ne01) {
 93        return;
 94    }
 95
 96    //int i12 = i03%ne12;
 97    //int i11 = i02%ne11;
 98    int i12 = fastmod(i03, ne12);
 99    int i11 = fastmod(i02, ne11);
100
101    int i10 = i01;
102    long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
103
104    global half  * dst_row = (global half  *) (dst  +  i1*nb1  + i02*nb2  + i03*nb3);
105    global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
106
107    for (int ind = get_local_id(0); ind < nblk0; ind += get_local_size(0)) {
108        dst_row[ind] = src_row[ind];
109    }
110}
111
// Copies float rows from src0 into dst, where the destination row of each
// source row is given by a 32-bit signed index read from src1.
//
// src0/offset0        : source buffer and byte offset; rows are read as float
// src1/offset1        : index buffer and byte offset; entries are read as int
// dst/offsetd         : destination buffer and byte offset; rows are written as float
// ne01                : exclusive upper bound on the source row index
// nb01, nb02, nb03    : byte strides of src0 for the row and the two outer coordinates
// ne11, ne12          : { mp, L, d } constants passed to fastmod for the two outer coordinates
// nb10, nb11, nb12    : byte strides of src1
// nblk0               : number of elements copied per row
// nb1, nb2, nb3       : byte strides of dst for the destination row and the two outer coordinates
kernel void kernel_set_rows_f32_i32(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        ulong         nb1,
        ulong         nb2,
        ulong         nb3
) {
    // Group dimensions 2 and 1 select the outer coordinates; the row index
    // combines group dimension 0 with local dimension 1.
    const int i3  = get_group_id(2);
    const int i2  = get_group_id(1);
    const int row = get_group_id(0)*get_local_size(1) + get_local_id(1);

    // Work-items whose row index is not below ne01 do nothing.
    if (row >= ne01) {
        return;
    }

    // Stand-ins for i3 % ne12 and i2 % ne11, using the packed fastmod constants.
    const int idx2 = fastmod(i3, ne12);
    const int idx1 = fastmod(i2, ne11);

    // Load the destination row number for this source row.
    global char * idx_ptr = src1 + offset1 + row*nb10 + idx1*nb11 + idx2*nb12;
    const int target = *(global int *) idx_ptr;

    global float * out = (global float *) (dst  + offsetd + target*nb1 + i2*nb2  + i3*nb3);
    global float * in  = (global float *) (src0 + offset0 + row*nb01   + i2*nb02 + i3*nb03);

    // Work-items along local dimension 0 share the row, each advancing by
    // the local size of that dimension.
    const int step = get_local_size(0);
    for (int col = get_local_id(0); col < nblk0; col += step) {
        out[col] = in[col];
    }
}
160
// Copies float rows from src0 into dst as half values, where the destination
// row of each source row is given by a 32-bit signed index read from src1.
//
// src0/offset0        : source buffer and byte offset; rows are read as float
// src1/offset1        : index buffer and byte offset; entries are read as int
// dst/offsetd         : destination buffer and byte offset; rows are written as half
// ne01                : exclusive upper bound on the source row index
// nb01, nb02, nb03    : byte strides of src0 for the row and the two outer coordinates
// ne11, ne12          : { mp, L, d } constants passed to fastmod for the two outer coordinates
// nb10, nb11, nb12    : byte strides of src1
// nblk0               : number of elements copied per row
// nb1, nb2, nb3       : byte strides of dst for the destination row and the two outer coordinates
kernel void kernel_set_rows_f16_i32(
        global char * src0,
        ulong         offset0,
        global char * src1,
        ulong         offset1,
        global char * dst,
        ulong         offsetd,
        int           ne01,
        ulong         nb01,
        ulong         nb02,
        ulong         nb03,
        uint4         ne11,
        uint4         ne12,
        ulong         nb10,
        ulong         nb11,
        ulong         nb12,
        int           nblk0,
        ulong         nb1,
        ulong         nb2,
        ulong         nb3
) {
    // Group dimensions 2 and 1 select the outer coordinates; the row index
    // combines group dimension 0 with local dimension 1.
    const int i3  = get_group_id(2);
    const int i2  = get_group_id(1);
    const int row = get_group_id(0)*get_local_size(1) + get_local_id(1);

    // Work-items whose row index is not below ne01 do nothing.
    if (row >= ne01) {
        return;
    }

    // Stand-ins for i3 % ne12 and i2 % ne11, using the packed fastmod constants.
    const int idx2 = fastmod(i3, ne12);
    const int idx1 = fastmod(i2, ne11);

    // Load the destination row number for this source row.
    global char * idx_ptr = src1 + offset1 + row*nb10 + idx1*nb11 + idx2*nb12;
    const int target = *(global int *) idx_ptr;

    global half  * out = (global half  *) (dst  + offsetd + target*nb1 + i2*nb2  + i3*nb3);
    global float * in  = (global float *) (src0 + offset0 + row*nb01   + i2*nb02 + i3*nb03);

    // Work-items along local dimension 0 share the row, each advancing by
    // the local size of that dimension. The store converts float to half
    // implicitly.
    const int step = get_local_size(0);
    for (int col = get_local_id(0); col < nblk0; col += step) {
        out[col] = in[col];
    }
}