aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-opencl/kernels/transpose.cl
blob: 1279b6531b949ee31c16440e7535618ef25dd3a6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// 16-bit transpose, loading/storing a 4x4 tile of elements
kernel void kernel_transpose_16(
    __read_only image1d_buffer_t input,
    __write_only image1d_buffer_t output,
    const uint rows,
    const uint cols
) {

    const int i = get_global_id(0);
    const int j = get_global_id(1);
    const int i_2 = i<<2;
    const int j_2 = j<<2;

    half4 temp0 = read_imageh(input, (j_2+0)*cols+i);
    half4 temp1 = read_imageh(input, (j_2+1)*cols+i);
    half4 temp2 = read_imageh(input, (j_2+2)*cols+i);
    half4 temp3 = read_imageh(input, (j_2+3)*cols+i);

    write_imageh(output, (i_2+0)*rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0));
    write_imageh(output, (i_2+1)*rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
    write_imageh(output, (i_2+2)*rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
    write_imageh(output, (i_2+3)*rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
}

// Padded kernel for irregular shape
kernel void kernel_transpose_16_4x1(
    __read_only image1d_buffer_t input,
    __write_only image1d_buffer_t output,
    const uint rows,
    const uint cols
) {

    const int i = get_global_id(0);
    const int j = get_global_id(1);
    const int j_2 = j << 2;

    half temp0 = read_imageh(input, (j_2 + 0) * cols + i).x;
    half temp1 = read_imageh(input, (j_2 + 1) * cols + i).x;
    half temp2 = read_imageh(input, (j_2 + 2) * cols + i).x;
    half temp3 = read_imageh(input, (j_2 + 3) * cols + i).x;

    write_imageh(output, i * rows + j, (half4)(temp0, temp1, temp2, temp3));
}

// Transpose treating each element as 16-bit using buffer
kernel void kernel_transpose_16_buf(
    global const ushort * input,
    global ushort * output,
    const int ldi,
    const int ldo
) {
    const int x = get_global_id(0);
    const int y = get_global_id(1);

    output[x*ldo + y] = input[y*ldi + x];
}

// 32-bit transpose, loading/storing a 4x4 tile of elements
kernel void kernel_transpose_32(
    __read_only image1d_buffer_t input,
    __write_only image1d_buffer_t output,
    const uint rows,
    const uint cols
) {

    const int i = get_global_id(0);
    const int j = get_global_id(1);
    const int i_2 = i<<2;
    const int j_2 = j<<2;

    float4 temp0 = read_imagef(input, (j_2+0)*cols+i);
    float4 temp1 = read_imagef(input, (j_2+1)*cols+i);
    float4 temp2 = read_imagef(input, (j_2+2)*cols+i);
    float4 temp3 = read_imagef(input, (j_2+3)*cols+i);

    write_imagef(output, (i_2+0)*rows+j, (float4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0));
    write_imagef(output, (i_2+1)*rows+j, (float4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
    write_imagef(output, (i_2+2)*rows+j, (float4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
    write_imagef(output, (i_2+3)*rows+j, (float4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));

}

// 32-bit transpose, loading/storing a 4x4 tile of elements
// Only used for activations
// converts to FP16
// also adds zero padding for non multiple of 8 prompt lengths
kernel void kernel_transpose_32_16(__read_only image1d_buffer_t input, __write_only image1d_buffer_t output, const uint rows, const uint cols, const uint padded_rows) {

    const int i = get_global_id(0);
    const int j = get_global_id(1);
    const int i_2 = i<<2;
    const int j_2 = j<<2;
    half4 temp0 = {0,0,0,0}; // initialize outputs to 0
    half4 temp1 = {0,0,0,0};
    half4 temp2 = {0,0,0,0};
    half4 temp3 = {0,0,0,0};

    if((j_2+0)*cols+i*4+3 < rows*cols*16){ // only load from a valid location. Otherwise keep register data as 0
        temp0 = read_imageh(input, (j_2+0)*cols+i);
    }
    if((j_2+1)*cols+i*4+3 < rows*cols*16){
        temp1 = read_imageh(input, (j_2+1)*cols+i);
    }
    if((j_2+2)*cols+i*4+3 < rows*cols*16){
        temp2 = read_imageh(input, (j_2+2)*cols+i);
    }
    if((j_2+3)*cols+i*4+3 < rows*cols*16){
        temp3 = read_imageh(input, (j_2+3)*cols+i);
    }

    write_imageh(output, (i_2+0)*padded_rows+j, (half4)(temp0.s0, temp1.s0, temp2.s0, temp3.s0)); // no conditionals for output, includes zero padding
    write_imageh(output, (i_2+1)*padded_rows+j, (half4)(temp0.s1, temp1.s1, temp2.s1, temp3.s1));
    write_imageh(output, (i_2+2)*padded_rows+j, (half4)(temp0.s2, temp1.s2, temp2.s2, temp3.s2));
    write_imageh(output, (i_2+3)*padded_rows+j, (half4)(temp0.s3, temp1.s3, temp2.s3, temp3.s3));
}