1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
// 16-bit transpose of one 4x4 tile of elements.
//
// Every texel of `input`/`output` is read/written as a half4 (four elements).
// Work-item (x, y) reads texel column x of the four source rows 4y..4y+3,
// where a source row is `cols` texels long. It writes the transposed tile to
// texel column y of the four destination rows 4x..4x+3, where a destination
// row is `rows` texels long.
// No bounds checks are performed; the launch is presumably sized exactly to
// (cols, rows) -- confirm against the host code.
kernel void kernel_transpose_16(
    __read_only image1d_buffer_t input,
    __write_only image1d_buffer_t output,
    const uint rows,
    const uint cols
) {
    const int x = get_global_id(0);
    const int y = get_global_id(1);

    // Index of the tile's first texel in the source and in the destination.
    const int src = (y << 2) * cols + x;
    const int dst = (x << 2) * rows + y;

    // One texel from each of the four consecutive source rows.
    const half4 a = read_imageh(input, src);
    const half4 b = read_imageh(input, src + cols);
    const half4 c = read_imageh(input, src + 2 * cols);
    const half4 d = read_imageh(input, src + 3 * cols);

    // Component k of the loaded texels becomes destination row 4x + k.
    write_imageh(output, dst,            (half4)(a.x, b.x, c.x, d.x));
    write_imageh(output, dst + rows,     (half4)(a.y, b.y, c.y, d.y));
    write_imageh(output, dst + 2 * rows, (half4)(a.z, b.z, c.z, d.z));
    write_imageh(output, dst + 3 * rows, (half4)(a.w, b.w, c.w, d.w));
}
26
// Padded 16-bit transpose for irregular shapes.
//
// Only the first component (.x) of each input texel is used. Work-item (x, y)
// gathers the .x values of texel column x from the four source rows 4y..4y+3
// (row stride `cols` texels). It packs them into one half4 and writes that to
// destination texel x * rows + y.
// NOTE(review): presumably the input image has a single-channel format, so
// each texel carries exactly one element -- confirm against the host setup.
kernel void kernel_transpose_16_4x1(
    __read_only image1d_buffer_t input,
    __write_only image1d_buffer_t output,
    const uint rows,
    const uint cols
) {
    const int x = get_global_id(0);
    const int y = get_global_id(1);

    // Texel (source row 4y, column x); the next rows are `cols` texels apart.
    const int src = (y << 2) * cols + x;

    const half4 packed = (half4)(read_imageh(input, src).x,
                                 read_imageh(input, src + cols).x,
                                 read_imageh(input, src + 2 * cols).x,
                                 read_imageh(input, src + 3 * cols).x);

    write_imageh(output, x * rows + y, packed);
}
46
// Element-wise transpose on plain global buffers.
//
// Each element is treated as an opaque 16-bit value (ushort). Work-item
// (x, y) copies source element (row y, column x) to destination element
// (row x, column y). `ldi` and `ldo` are the row strides, in elements, of
// `input` and `output`.
// No bounds checks are performed; the launch is presumably sized to the
// matrix extent -- confirm against the host code.
kernel void kernel_transpose_16_buf(
    global const ushort * input,
    global ushort * output,
    const int ldi,
    const int ldo
) {
    const int col = get_global_id(0);
    const int row = get_global_id(1);

    const int src = row * ldi + col;
    const int dst = col * ldo + row;

    output[dst] = input[src];
}
59
// 32-bit transpose of one 4x4 tile of elements.
//
// Every texel of `input`/`output` is read/written as a float4 (four
// elements). Work-item (x, y) reads texel column x of the four source rows
// 4y..4y+3, where a source row is `cols` texels long. It writes the
// transposed tile to texel column y of the four destination rows 4x..4x+3,
// where a destination row is `rows` texels long.
// No bounds checks are performed; the launch is presumably sized exactly to
// (cols, rows) -- confirm against the host code.
kernel void kernel_transpose_32(
    __read_only image1d_buffer_t input,
    __write_only image1d_buffer_t output,
    const uint rows,
    const uint cols
) {
    const int x = get_global_id(0);
    const int y = get_global_id(1);

    // Index of the tile's first texel in the source and in the destination.
    const int src = (y << 2) * cols + x;
    const int dst = (x << 2) * rows + y;

    // One texel from each of the four consecutive source rows.
    const float4 a = read_imagef(input, src);
    const float4 b = read_imagef(input, src + cols);
    const float4 c = read_imagef(input, src + 2 * cols);
    const float4 d = read_imagef(input, src + 3 * cols);

    // Component k of the loaded texels becomes destination row 4x + k.
    write_imagef(output, dst,            (float4)(a.x, b.x, c.x, d.x));
    write_imagef(output, dst + rows,     (float4)(a.y, b.y, c.y, d.y));
    write_imagef(output, dst + 2 * rows, (float4)(a.z, b.z, c.z, d.z));
    write_imagef(output, dst + 3 * rows, (float4)(a.w, b.w, c.w, d.w));
}
84
// 32-bit -> 16-bit transpose of one 4x4 tile of elements.
// Only used for activations.
//
// Input texels are converted to FP16 by read_imageh. Work-item (i, j) loads
// texel column i of the four source rows 4j..4j+3 (row stride `cols`
// texels). It writes the transposed tile to texel column j of destination
// rows 4i..4i+3 (row stride `padded_rows` texels).
//
// Zero padding for non-multiple-of-8 prompt lengths: the launch may cover
// more source rows than the input holds. Any source texel at or beyond the
// end of the input image is treated as zero, so the padded part of the
// output is zero-filled.
//
// rows        - NOTE(review): appears to be the number of 4-row source tiles
//               (the old bound used rows*cols*16). The bounds check no longer
//               needs it; it is kept so the argument list is unchanged.
// cols        - source row stride, in texels.
// padded_rows - destination row stride, in texels.
kernel void kernel_transpose_32_16(
    __read_only image1d_buffer_t input,
    __write_only image1d_buffer_t output,
    const uint rows,
    const uint cols,
    const uint padded_rows
) {
    const int i = get_global_id(0);
    const int j = get_global_id(1);
    const int i_2 = i << 2;
    const int j_2 = j << 2;

    // Number of texels actually backed by the input image.
    // The previous guard `(j_2+k)*cols + i*4 + 3 < rows*cols*16` mixed texel
    // and element units. It accepted rows far past the end of the image, and
    // sampler-less reads at out-of-range coordinates are undefined rather than
    // zero. Comparing the real texel index against the image width reads
    // every in-range texel and no out-of-range one.
    const int in_texels = get_image_width(input);

    // Texel index of (source row 4j + k, texel column i) for k = 0..3.
    const int src0 = (j_2 + 0) * cols + i;
    const int src1 = (j_2 + 1) * cols + i;
    const int src2 = (j_2 + 2) * cols + i;
    const int src3 = (j_2 + 3) * cols + i;

    // Rows past the end of the input keep these zeros (the padding).
    half4 temp0 = {0, 0, 0, 0};
    half4 temp1 = {0, 0, 0, 0};
    half4 temp2 = {0, 0, 0, 0};
    half4 temp3 = {0, 0, 0, 0};

    if (src0 < in_texels) {
        temp0 = read_imageh(input, src0);
    }
    if (src1 < in_texels) {
        temp1 = read_imageh(input, src1);
    }
    if (src2 < in_texels) {
        temp2 = read_imageh(input, src2);
    }
    if (src3 < in_texels) {
        temp3 = read_imageh(input, src3);
    }

    // No conditionals on the output, so the zero padding is stored as well.
    // Presumably the output image spans 4*cols rows of padded_rows texels --
    // confirm against the host launch.
    write_imageh(output, (i_2 + 0) * padded_rows + j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0));
    write_imageh(output, (i_2 + 1) * padded_rows + j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
    write_imageh(output, (i_2 + 2) * padded_rows + j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
    write_imageh(output, (i_2 + 3) * padded_rows + j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
}