1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
  2
  3//------------------------------------------------------------------------------
  4// cpy
  5//------------------------------------------------------------------------------
  6
  7kernel void kernel_cpy_f16_f16(
  8        global half * src0,
  9        ulong offset0,
 10        global half * dst,
 11        ulong offsetd,
 12        int ne00,
 13        int ne01,
 14        int ne02,
 15        int ne03,
 16        ulong nb00,
 17        ulong nb01,
 18        ulong nb02,
 19        ulong nb03,
 20        int ne0,
 21        int ne1,
 22        int ne2,
 23        int ne3,
 24        ulong nb0,
 25        ulong nb1,
 26        ulong nb2,
 27        ulong nb3
 28) {
 29    src0 = (global half*)((global char*)src0 + offset0);
 30    dst = (global half*)((global char*)dst + offsetd);
 31
 32    int i03 = get_group_id(2);
 33    int i02 = get_group_id(1);
 34    int i01 = get_group_id(0);
 35
 36    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 37
 38    int i3 = n / (ne2*ne1*ne0);
 39    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
 40    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
 41    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
 42
 43    global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 44
 45    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
 46        global const half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
 47        dst_data[i00] = src[0];
 48    }
 49}
 50
 51kernel void kernel_cpy_f16_f32(
 52        global half * src0,
 53        ulong offset0,
 54        global float * dst,
 55        ulong offsetd,
 56        int ne00,
 57        int ne01,
 58        int ne02,
 59        int ne03,
 60        ulong nb00,
 61        ulong nb01,
 62        ulong nb02,
 63        ulong nb03,
 64        int ne0,
 65        int ne1,
 66        int ne2,
 67        int ne3,
 68        ulong nb0,
 69        ulong nb1,
 70        ulong nb2,
 71        ulong nb3
 72) {
 73
 74    src0 = (global half*)((global char*)src0 + offset0);
 75    dst = (global float*)((global char*)dst + offsetd);
 76
 77    int i03 = get_group_id(2);
 78    int i02 = get_group_id(1);
 79    int i01 = get_group_id(0);
 80
 81    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 82
 83    int i3 = n / (ne2*ne1*ne0);
 84    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
 85    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
 86    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
 87
 88    global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 89
 90    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
 91        global half * src = (global half *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
 92        dst_data[i00] = src[0];
 93    }
 94}
 95
 96kernel void kernel_cpy_f32_f16(
 97        global float * src0,
 98        ulong offset0,
 99        global half * dst,
100        ulong offsetd,
101        int ne00,
102        int ne01,
103        int ne02,
104        int ne03,
105        ulong nb00,
106        ulong nb01,
107        ulong nb02,
108        ulong nb03,
109        int ne0,
110        int ne1,
111        int ne2,
112        int ne3,
113        ulong nb0,
114        ulong nb1,
115        ulong nb2,
116        ulong nb3
117) {
118    src0 = (global float*)((global char*)src0 + offset0);
119    dst = (global half*)((global char*)dst + offsetd);
120
121    int i03 = get_group_id(2);
122    int i02 = get_group_id(1);
123    int i01 = get_group_id(0);
124
125    int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
126
127    int i3 = n / (ne2*ne1*ne0);
128    int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
129    int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
130    int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
131
132    global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
133
134    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
135        global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
136
137        dst_data[i00] = src[0];
138    }
139}
140
// Copy an f32 tensor src0 into an f32 tensor dst.
//
// Launch layout, as read by this kernel: each work-group handles one source
// row, with (i01, i02, i03) = (get_group_id(0), get_group_id(1), get_group_id(2)).
// The work-items of the group stride over that row's ne00 elements by
// get_local_size(0).
//
// src0 / dst       : buffers; offset0 / offsetd are byte offsets applied first.
// ne00..ne03, nb00..nb03 : source extents and byte strides.
// ne0..ne3,   nb0..nb3   : destination extents and byte strides.
// ne03 and ne3 are not read by this kernel.
//
// The flat element index of the source row is decomposed into 4-D
// destination coordinates, so src0 and dst may have different shapes.
// NOTE(review): the row is then written as ne00 consecutive elements
// (dst_data[i00]), i.e. dst is assumed to be element-contiguous from that
// position onward -- confirm the host only dispatches this for such layouts.
kernel void kernel_cpy_f32_f32(
        global float * src0,
        ulong offset0,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    src0 = (global float*)((global char*)src0 + offset0);
    dst  = (global float*)((global char*)dst  + offsetd);

    const int i03 = get_group_id(2);
    const int i02 = get_group_id(1);
    const int i01 = get_group_id(0);

    // All index math is 64-bit: in 32-bit int these products overflow
    // (undefined behavior) once a tensor holds >= 2^31 elements.
    const long n = (((long)i03*ne02 + i02)*ne01 + i01)*ne00;

    const long dst_plane  = (long)ne1*ne0;
    const long dst_volume = dst_plane*ne2;

    // Decompose n into destination coordinates (i0 varies fastest).
    long rem = n;
    const long i3 = rem / dst_volume; rem -= i3*dst_volume;
    const long i2 = rem / dst_plane;  rem -= i2*dst_plane;
    const long i1 = rem / ne0;
    const long i0 = rem - i1*ne0;

    global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    // Base of this source row; only the i00*nb00 term varies in the loop.
    global const char * src_row = (global const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
        dst_data[i00] = *(global const float *)(src_row + i00*nb00);
    }
}