1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
// Fast unsigned 32-bit division n / d by a divisor that is fixed for the
// whole launch, using a precomputed multiplier (round-up method of
// Granlund & Montgomery) instead of an integer divide.
//
// v = { mp, L, d }: only v.s0 (mp) and v.s1 (L) are read here; v.s2 (d) is
// used by fastmod. mp and L are produced on the host (not visible in this
// file) -- presumably L = ceil(log2(d)) and
// mp = floor(2^32 * (2^L - d) / d) + 1; TODO confirm against the host code.
//
// Why 64-bit: mul_hi(n, mp) + n can exceed 32 bits for large n. In addition,
// OpenCL masks shift counts to the operand width, so a 32-bit ">> 32" (L == 32,
// i.e. d > 2^31) would silently become ">> 0". Widening the sum to ulong keeps
// the exact value and makes a shift of 32 well-defined. For inputs where the
// 32-bit form did not wrap and L < 32, the result is unchanged.
inline uint fastdiv(uint n, uint4 v) {
    ulong hi = mul_hi(n, v.s0);
    return (uint)((hi + n) >> v.s1);
}
// Remainder n mod d, where d is the divisor packed in v.s2 of the
// { mp, L, d } triple. The quotient comes from fastdiv, so no '%' or '/'
// operator is used; the remainder is n minus d times that quotient.
inline uint fastmod(uint n, uint4 v) {
    return n - v.s2 * fastdiv(n, v);
}
13
// SET_ROWS scatter: f32 source rows -> f32 destination rows, 64-bit (long)
// row indices.
//
// Work mapping, as read from the get_* calls:
//   get_group_id(2)                                  -> i03
//   get_group_id(1)                                  -> i02
//   get_group_id(0)*get_local_size(1)+get_local_id(1) -> i01 (source row)
//   get_local_id(0), stepping by get_local_size(0)   -> element within the row
// Source row i01 of the (i02, i03) slice of src0 is copied, for its first
// nblk0 elements, into row i1 of the same slice of dst, where
// i1 = src1[i01, i02 mod d(ne11), i03 mod d(ne12)].
// All nb* strides and offset* values are in bytes. ne11/ne12 are fastmod
// divisor packs { mp, L, d }. i1 is not range-checked here; presumably the
// host guarantees valid indices.
kernel void kernel_set_rows_f32_i64(
    global char * src0,
    ulong offset0,
    global char * src1,
    ulong offset1,
    global char * dst,
    ulong offsetd,
    int ne01,
    ulong nb01,
    ulong nb02,
    ulong nb03,
    uint4 ne11,
    uint4 ne12,
    ulong nb10,
    ulong nb11,
    ulong nb12,
    int nblk0,
    ulong nb1,
    ulong nb2,
    ulong nb3
) {
    const int row = get_group_id(0)*get_local_size(1) + get_local_id(1); // i01
    if (row >= ne01) {
        return; // work-items past the last source row do nothing
    }
    const int plane = get_group_id(1); // i02
    const int slab  = get_group_id(2); // i03

    // src1 is wrapped over dims 2/3: i11 = i02 % ne11, i12 = i03 % ne12.
    const int idx_plane = fastmod(plane, ne11);
    const int idx_slab  = fastmod(slab,  ne12);

    global const char * src_base = src0 + offset0;
    global const char * idx_base = src1 + offset1;
    global char       * dst_base = dst  + offsetd;

    const long target =
        *(global const long *)(idx_base + row*nb10 + idx_plane*nb11 + idx_slab*nb12);

    global const float * from = (global const float *)(src_base + row*nb01 + plane*nb02 + slab*nb03);
    global float       * to   = (global float *)(dst_base + target*nb1 + plane*nb2 + slab*nb3);

    for (int k = get_local_id(0); k < nblk0; k += get_local_size(0)) {
        to[k] = from[k];
    }
}
62
// SET_ROWS scatter: f32 source rows -> f16 (half) destination rows, 64-bit
// (long) row indices.
//
// Work mapping, as read from the get_* calls:
//   get_group_id(2)                                  -> i03
//   get_group_id(1)                                  -> i02
//   get_group_id(0)*get_local_size(1)+get_local_id(1) -> i01 (source row)
//   get_local_id(0), stepping by get_local_size(0)   -> element within the row
// Source row i01 of the (i02, i03) slice of src0 is converted to half and
// stored, for its first nblk0 elements, into row i1 of the same slice of dst,
// where i1 = src1[i01, i02 mod d(ne11), i03 mod d(ne12)].
// All nb* strides and offset* values are in bytes. ne11/ne12 are fastmod
// divisor packs { mp, L, d }. i1 is not range-checked here; presumably the
// host guarantees valid indices.
kernel void kernel_set_rows_f16_i64(
    global char * src0,
    ulong offset0,
    global char * src1,
    ulong offset1,
    global char * dst,
    ulong offsetd,
    int ne01,
    ulong nb01,
    ulong nb02,
    ulong nb03,
    uint4 ne11,
    uint4 ne12,
    ulong nb10,
    ulong nb11,
    ulong nb12,
    int nblk0,
    ulong nb1,
    ulong nb2,
    ulong nb3
) {
    const int row = get_group_id(0)*get_local_size(1) + get_local_id(1); // i01
    if (row >= ne01) {
        return; // work-items past the last source row do nothing
    }
    const int plane = get_group_id(1); // i02
    const int slab  = get_group_id(2); // i03

    // src1 is wrapped over dims 2/3: i11 = i02 % ne11, i12 = i03 % ne12.
    const int idx_plane = fastmod(plane, ne11);
    const int idx_slab  = fastmod(slab,  ne12);

    global const char * src_base = src0 + offset0;
    global const char * idx_base = src1 + offset1;
    global char       * dst_base = dst  + offsetd;

    const long target =
        *(global const long *)(idx_base + row*nb10 + idx_plane*nb11 + idx_slab*nb12);

    global const float * from = (global const float *)(src_base + row*nb01 + plane*nb02 + slab*nb03);
    global half        * to   = (global half *)(dst_base + target*nb1 + plane*nb2 + slab*nb3);

    for (int k = get_local_id(0); k < nblk0; k += get_local_size(0)) {
        to[k] = (half)from[k]; // same conversion the assignment would apply implicitly
    }
}
111
// SET_ROWS scatter: f32 source rows -> f32 destination rows, 32-bit (int)
// row indices.
//
// Work mapping, as read from the get_* calls:
//   get_group_id(2)                                  -> i03
//   get_group_id(1)                                  -> i02
//   get_group_id(0)*get_local_size(1)+get_local_id(1) -> i01 (source row)
//   get_local_id(0), stepping by get_local_size(0)   -> element within the row
// Source row i01 of the (i02, i03) slice of src0 is copied, for its first
// nblk0 elements, into row i1 of the same slice of dst, where
// i1 = src1[i01, i02 mod d(ne11), i03 mod d(ne12)].
// All nb* strides and offset* values are in bytes. ne11/ne12 are fastmod
// divisor packs { mp, L, d }. i1 is not range-checked here; presumably the
// host guarantees valid indices.
kernel void kernel_set_rows_f32_i32(
    global char * src0,
    ulong offset0,
    global char * src1,
    ulong offset1,
    global char * dst,
    ulong offsetd,
    int ne01,
    ulong nb01,
    ulong nb02,
    ulong nb03,
    uint4 ne11,
    uint4 ne12,
    ulong nb10,
    ulong nb11,
    ulong nb12,
    int nblk0,
    ulong nb1,
    ulong nb2,
    ulong nb3
) {
    const int row = get_group_id(0)*get_local_size(1) + get_local_id(1); // i01
    if (row >= ne01) {
        return; // work-items past the last source row do nothing
    }
    const int plane = get_group_id(1); // i02
    const int slab  = get_group_id(2); // i03

    // src1 is wrapped over dims 2/3: i11 = i02 % ne11, i12 = i03 % ne12.
    const int idx_plane = fastmod(plane, ne11);
    const int idx_slab  = fastmod(slab,  ne12);

    global const char * src_base = src0 + offset0;
    global const char * idx_base = src1 + offset1;
    global char       * dst_base = dst  + offsetd;

    const int target =
        *(global const int *)(idx_base + row*nb10 + idx_plane*nb11 + idx_slab*nb12);

    global const float * from = (global const float *)(src_base + row*nb01 + plane*nb02 + slab*nb03);
    global float       * to   = (global float *)(dst_base + target*nb1 + plane*nb2 + slab*nb3);

    for (int k = get_local_id(0); k < nblk0; k += get_local_size(0)) {
        to[k] = from[k];
    }
}
160
// SET_ROWS scatter: f32 source rows -> f16 (half) destination rows, 32-bit
// (int) row indices.
//
// Work mapping, as read from the get_* calls:
//   get_group_id(2)                                  -> i03
//   get_group_id(1)                                  -> i02
//   get_group_id(0)*get_local_size(1)+get_local_id(1) -> i01 (source row)
//   get_local_id(0), stepping by get_local_size(0)   -> element within the row
// Source row i01 of the (i02, i03) slice of src0 is converted to half and
// stored, for its first nblk0 elements, into row i1 of the same slice of dst,
// where i1 = src1[i01, i02 mod d(ne11), i03 mod d(ne12)].
// All nb* strides and offset* values are in bytes. ne11/ne12 are fastmod
// divisor packs { mp, L, d }. i1 is not range-checked here; presumably the
// host guarantees valid indices.
kernel void kernel_set_rows_f16_i32(
    global char * src0,
    ulong offset0,
    global char * src1,
    ulong offset1,
    global char * dst,
    ulong offsetd,
    int ne01,
    ulong nb01,
    ulong nb02,
    ulong nb03,
    uint4 ne11,
    uint4 ne12,
    ulong nb10,
    ulong nb11,
    ulong nb12,
    int nblk0,
    ulong nb1,
    ulong nb2,
    ulong nb3
) {
    const int row = get_group_id(0)*get_local_size(1) + get_local_id(1); // i01
    if (row >= ne01) {
        return; // work-items past the last source row do nothing
    }
    const int plane = get_group_id(1); // i02
    const int slab  = get_group_id(2); // i03

    // src1 is wrapped over dims 2/3: i11 = i02 % ne11, i12 = i03 % ne12.
    const int idx_plane = fastmod(plane, ne11);
    const int idx_slab  = fastmod(slab,  ne12);

    global const char * src_base = src0 + offset0;
    global const char * idx_base = src1 + offset1;
    global char       * dst_base = dst  + offsetd;

    const int target =
        *(global const int *)(idx_base + row*nb10 + idx_plane*nb11 + idx_slab*nb12);

    global const float * from = (global const float *)(src_base + row*nb01 + plane*nb02 + slab*nb03);
    global half        * to   = (global half *)(dst_base + target*nb1 + plane*nb2 + slab*nb3);

    for (int k = get_local_id(0); k < nblk0; k += get_local_size(0)) {
        to[k] = (half)from[k]; // same conversion the assignment would apply implicitly
    }
}