1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
3//------------------------------------------------------------------------------
4// cpy
5//------------------------------------------------------------------------------
6
// Copy a strided f16 tensor (src0) into an f16 tensor (dst).
//
// Work decomposition (from the indexing below): each work-group handles one
// src0 row, selected by group ids (0, 1, 2) = (i01, i02, i03); the work-items
// of the group stride over the ne00 elements of that row.
//
// The flat element index of the row start (in src0's logical order) is
// re-decomposed with dst's shape (ne0, ne1, ne2), so dst may be shaped
// differently from src0 while receiving elements in the same flat order.
// NOTE(review): within a row, elements are written as dst_data[i00], i.e. at
// sizeof(half) spacing starting from (i0, i1, i2, i3), ignoring nb0 and row
// boundaries of dst; this assumes dst is contiguous — confirm on the host side.
//
// offset0/offsetd are byte offsets applied to src0/dst; nb* are byte strides.
// Flat-index arithmetic is done in 64-bit `long` so tensors with more than
// INT_MAX elements do not hit signed int overflow (undefined behaviour).
kernel void kernel_cpy_f16_f16(
        global half * src0,
        ulong offset0,
        global half * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    src0 = (global half*)((global char*)src0 + offset0);
    dst  = (global half*)((global char*)dst + offsetd);

    const long i03 = get_group_id(2);
    const long i02 = get_group_id(1);
    const long i01 = get_group_id(0);

    // Flat index of this row's first element (64-bit to avoid int overflow).
    const long n = ((i03*ne02 + i02)*ne01 + i01)*ne00;

    // Number of dst elements per unit step of i2 and i3 respectively.
    const long elems_per_i2 = (long) ne0*ne1;
    const long elems_per_i3 = elems_per_i2*ne2;

    // Decompose n into dst coordinates.
    const long i3  = n / elems_per_i3;
    const long rem = n - i3*elems_per_i3;
    const long i2  = rem / elems_per_i2;
    const long rem2 = rem - i2*elems_per_i2;
    const long i1  = rem2 / ne0;
    const long i0  = rem2 - i1*ne0;

    global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    // Start of this src0 row; loop-invariant, so computed once.
    global const char * src_row = (global const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
        dst_data[i00] = *(global const half *)(src_row + i00*nb00);
    }
}
50
// Copy a strided f16 tensor (src0) into an f32 tensor (dst), widening each
// element from half to float.
//
// Work decomposition (from the indexing below): each work-group handles one
// src0 row, selected by group ids (0, 1, 2) = (i01, i02, i03); the work-items
// of the group stride over the ne00 elements of that row.
//
// The flat element index of the row start (in src0's logical order) is
// re-decomposed with dst's shape (ne0, ne1, ne2), so dst may be shaped
// differently from src0 while receiving elements in the same flat order.
// NOTE(review): within a row, elements are written as dst_data[i00], i.e. at
// sizeof(float) spacing starting from (i0, i1, i2, i3), ignoring nb0 and row
// boundaries of dst; this assumes dst is contiguous — confirm on the host side.
//
// offset0/offsetd are byte offsets applied to src0/dst; nb* are byte strides.
// Flat-index arithmetic is done in 64-bit `long` so tensors with more than
// INT_MAX elements do not hit signed int overflow (undefined behaviour).
kernel void kernel_cpy_f16_f32(
        global half * src0,
        ulong offset0,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    src0 = (global half*)((global char*)src0 + offset0);
    dst  = (global float*)((global char*)dst + offsetd);

    const long i03 = get_group_id(2);
    const long i02 = get_group_id(1);
    const long i01 = get_group_id(0);

    // Flat index of this row's first element (64-bit to avoid int overflow).
    const long n = ((i03*ne02 + i02)*ne01 + i01)*ne00;

    // Number of dst elements per unit step of i2 and i3 respectively.
    const long elems_per_i2 = (long) ne0*ne1;
    const long elems_per_i3 = elems_per_i2*ne2;

    // Decompose n into dst coordinates.
    const long i3  = n / elems_per_i3;
    const long rem = n - i3*elems_per_i3;
    const long i2  = rem / elems_per_i2;
    const long rem2 = rem - i2*elems_per_i2;
    const long i1  = rem2 / ne0;
    const long i0  = rem2 - i1*ne0;

    global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    // Start of this src0 row; loop-invariant, so computed once.
    global const char * src_row = (global const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
        // half -> float widening is exact.
        dst_data[i00] = (float) *(global const half *)(src_row + i00*nb00);
    }
}
95
// Copy a strided f32 tensor (src0) into an f16 tensor (dst), narrowing each
// element from float to half.
//
// Work decomposition (from the indexing below): each work-group handles one
// src0 row, selected by group ids (0, 1, 2) = (i01, i02, i03); the work-items
// of the group stride over the ne00 elements of that row.
//
// The flat element index of the row start (in src0's logical order) is
// re-decomposed with dst's shape (ne0, ne1, ne2), so dst may be shaped
// differently from src0 while receiving elements in the same flat order.
// NOTE(review): within a row, elements are written as dst_data[i00], i.e. at
// sizeof(half) spacing starting from (i0, i1, i2, i3), ignoring nb0 and row
// boundaries of dst; this assumes dst is contiguous — confirm on the host side.
//
// offset0/offsetd are byte offsets applied to src0/dst; nb* are byte strides.
// Flat-index arithmetic is done in 64-bit `long` so tensors with more than
// INT_MAX elements do not hit signed int overflow (undefined behaviour).
kernel void kernel_cpy_f32_f16(
        global float * src0,
        ulong offset0,
        global half * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    src0 = (global float*)((global char*)src0 + offset0);
    dst  = (global half*)((global char*)dst + offsetd);

    const long i03 = get_group_id(2);
    const long i02 = get_group_id(1);
    const long i01 = get_group_id(0);

    // Flat index of this row's first element (64-bit to avoid int overflow).
    const long n = ((i03*ne02 + i02)*ne01 + i01)*ne00;

    // Number of dst elements per unit step of i2 and i3 respectively.
    const long elems_per_i2 = (long) ne0*ne1;
    const long elems_per_i3 = elems_per_i2*ne2;

    // Decompose n into dst coordinates.
    const long i3  = n / elems_per_i3;
    const long rem = n - i3*elems_per_i3;
    const long i2  = rem / elems_per_i2;
    const long rem2 = rem - i2*elems_per_i2;
    const long i1  = rem2 / ne0;
    const long i0  = rem2 - i1*ne0;

    global half * dst_data = (global half *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    // Start of this src0 row; loop-invariant, so computed once.
    global const char * src_row = (global const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
        // float -> half narrowing; an explicit cast follows the same rounding
        // rules as the implicit conversion.
        dst_data[i00] = (half) *(global const float *)(src_row + i00*nb00);
    }
}
140
// Copy a strided f32 tensor (src0) into an f32 tensor (dst).
//
// Work decomposition (from the indexing below): each work-group handles one
// src0 row, selected by group ids (0, 1, 2) = (i01, i02, i03); the work-items
// of the group stride over the ne00 elements of that row.
//
// The flat element index of the row start (in src0's logical order) is
// re-decomposed with dst's shape (ne0, ne1, ne2), so dst may be shaped
// differently from src0 while receiving elements in the same flat order.
// NOTE(review): within a row, elements are written as dst_data[i00], i.e. at
// sizeof(float) spacing starting from (i0, i1, i2, i3), ignoring nb0 and row
// boundaries of dst; this assumes dst is contiguous — confirm on the host side.
//
// offset0/offsetd are byte offsets applied to src0/dst; nb* are byte strides.
// Flat-index arithmetic is done in 64-bit `long` so tensors with more than
// INT_MAX elements do not hit signed int overflow (undefined behaviour).
kernel void kernel_cpy_f32_f32(
        global float * src0,
        ulong offset0,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3
) {
    src0 = (global float*)((global char*)src0 + offset0);
    dst  = (global float*)((global char*)dst + offsetd);

    const long i03 = get_group_id(2);
    const long i02 = get_group_id(1);
    const long i01 = get_group_id(0);

    // Flat index of this row's first element (64-bit to avoid int overflow).
    const long n = ((i03*ne02 + i02)*ne01 + i01)*ne00;

    // Number of dst elements per unit step of i2 and i3 respectively.
    const long elems_per_i2 = (long) ne0*ne1;
    const long elems_per_i3 = elems_per_i2*ne2;

    // Decompose n into dst coordinates.
    const long i3  = n / elems_per_i3;
    const long rem = n - i3*elems_per_i3;
    const long i2  = rem / elems_per_i2;
    const long rem2 = rem - i2*elems_per_i2;
    const long i1  = rem2 / ne0;
    const long i0  = rem2 - i1*ne0;

    global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    // Start of this src0 row; loop-invariant, so computed once.
    global const char * src_row = (global const char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;

    for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) {
        dst_data[i00] = *(global const float *)(src_row + i00*nb00);
    }
}