1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
  2
  3// 16-bit transpose, loading/storing a 4x4 tile of elements
  4kernel void kernel_transpose_16(
  5    __read_only image1d_buffer_t input,
  6    __write_only image1d_buffer_t output,
  7    const uint rows,
  8    const uint cols
  9) {
 10
 11    const int i = get_global_id(0);
 12    const int j = get_global_id(1);
 13    const int i_2 = i<<2;
 14    const int j_2 = j<<2;
 15
 16    half4 temp0 = read_imageh(input, (j_2+0)*cols+i);
 17    half4 temp1 = read_imageh(input, (j_2+1)*cols+i);
 18    half4 temp2 = read_imageh(input, (j_2+2)*cols+i);
 19    half4 temp3 = read_imageh(input, (j_2+3)*cols+i);
 20
 21    write_imageh(output, (i_2+0)*rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0));
 22    write_imageh(output, (i_2+1)*rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
 23    write_imageh(output, (i_2+2)*rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
 24    write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
 25}
 26
 27// Padded kernel for irregular shape
 28kernel void kernel_transpose_16_4x1(
 29    __read_only image1d_buffer_t input,
 30    __write_only image1d_buffer_t output,
 31    const uint rows,
 32    const uint cols
 33) {
 34
 35    const int i = get_global_id(0);
 36    const int j = get_global_id(1);
 37    const int j_2 = j << 2;
 38
 39    half temp0 = read_imageh(input, (j_2 + 0) * cols + i).x;
 40    half temp1 = read_imageh(input, (j_2 + 1) * cols + i).x;
 41    half temp2 = read_imageh(input, (j_2 + 2) * cols + i).x;
 42    half temp3 = read_imageh(input, (j_2 + 3) * cols + i).x;
 43
 44    write_imageh(output, i * rows + j, (half4)(temp0, temp1, temp2, temp3));
 45}
 46
 47// Transpose treating each element as 16-bit using buffer
 48kernel void kernel_transpose_16_buf(
 49    global const ushort * input,
 50    global ushort * output,
 51    const int ldi,
 52    const int ldo
 53) {
 54    const int x = get_global_id(0);
 55    const int y = get_global_id(1);
 56
 57    output[x*ldo + y] = input[y*ldi + x];
 58}
 59
 60// 32-bit transpose, loading/storing a 4x4 tile of elements
 61kernel void kernel_transpose_32(
 62    __read_only image1d_buffer_t input,
 63    __write_only image1d_buffer_t output,
 64    const uint rows,
 65    const uint cols
 66) {
 67
 68    const int i = get_global_id(0);
 69    const int j = get_global_id(1);
 70    const int i_2 = i<<2;
 71    const int j_2 = j<<2;
 72
 73    float4 temp0 = read_imagef(input, (j_2+0)*cols+i);
 74    float4 temp1 = read_imagef(input, (j_2+1)*cols+i);
 75    float4 temp2 = read_imagef(input, (j_2+2)*cols+i);
 76    float4 temp3 = read_imagef(input, (j_2+3)*cols+i);
 77
 78    write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0));
 79    write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
 80    write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
 81    write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
 82
 83}
 84
 85// 32-bit transpose, loading/storing a 4x4 tile of elements
 86// Only used for activations
 87// converts to FP16
 88// also adds zero padding for non multiple of 8 prompt lengths
 89kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint rows, const uint cols, const uint padded_rows) {
 90
 91    const int i = get_global_id(0);
 92    const int j = get_global_id(1);
 93    const int i_2 = i<<2;
 94    const int j_2 = j<<2;
 95    half4 temp0 = {0,0,0,0}; // initialize outputs to 0
 96    half4 temp1 = {0,0,0,0};
 97    half4 temp2 = {0,0,0,0};
 98    half4 temp3 = {0,0,0,0};
 99
100    if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. Otherwise keep register data as 0
101        temp0 = read_imageh(input, (j_2+0)*cols+i);
102    }
103    if((j_2+1)*cols+i*4+3 < rows*cols*16){
104        temp1 = read_imageh(input, (j_2+1)*cols+i);
105    }
106    if((j_2+2)*cols+i*4+3 < rows*cols*16){
107        temp2 = read_imageh(input, (j_2+2)*cols+i);
108    }
109    if((j_2+3)*cols+i*4+3 < rows*cols*16){
110        temp3 = read_imageh(input, (j_2+3)*cols+i);
111    }
112
113    write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding
114    write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
115    write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
116    write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
117}