aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-opencl/kernels/tsembd.cl
blob: 21444bd958298ee9e1d3ebe613892e9655be57bd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
// Sinusoidal timestep embedding.
//
// Each work-item handles one (row, col) pair, with row = get_global_id(1) and
// col = get_global_id(0). Let half_dim = logical_dim / 2. For col < half_dim the
// work-item reads t = timesteps[row] and writes
//
//     dst_row[col]            = cos(t * freq)
//     dst_row[col + half_dim] = sin(t * freq)
//     freq = exp(-log(max_period) * col / half_dim)
//
// When logical_dim is odd, the work-item with col == half_dim writes 0.0f to
// dst_row[2 * half_dim]. All other work-items with col >= half_dim do nothing.
//
// Parameters:
//   p_timesteps   - base of the input buffer; read as float after applying the offset
//   off_timesteps - byte offset added to p_timesteps
//   p_dst         - base of the output buffer; written as float after applying the offset
//   off_dst       - byte offset added to p_dst
//   dst_nb1_bytes - byte distance between consecutive output rows
//   logical_dim   - embedding width used to derive half_dim and the odd-width padding slot
//   max_period    - period constant used in the frequency formula above
//
// NOTE(review): the padding write assumes the host launches at least
// half_dim + 1 work-items along dimension 0 when logical_dim is odd - confirm
// against the enqueue code.
kernel void kernel_timestep_embedding(
    global const void * p_timesteps,
    ulong off_timesteps,
    global void * p_dst,
    ulong off_dst,
    int dst_nb1_bytes,
    int logical_dim,
    int max_period
) {
    // Keep the const qualifier while doing the byte-offset arithmetic.
    global const float * timesteps =
        (global const float *)((global const char *)p_timesteps + off_timesteps);
    global char * dst_base = (global char *)p_dst + off_dst;

    const int row = get_global_id(1);
    const int col = get_global_id(0);
    const int half_dim = logical_dim / 2;

    // Compute the row offset in 64-bit so row * stride cannot overflow int
    // for large destination buffers.
    global float * embed_row =
        (global float *)(dst_base + (long)row * (long)dst_nb1_bytes);

    // Odd width: exactly one work-item per row zeroes the trailing slot.
    if (logical_dim % 2 != 0 && col == half_dim) {
        embed_row[2 * half_dim] = 0.0f;
    }

    if (col >= half_dim) {
        return;
    }

    // From here on 0 <= col < half_dim, so half_dim is non-zero and the
    // division below is well defined; no separate half_dim == 0 case is needed.
    const float timestep = timesteps[row];
    const float freq = exp(-log((float)max_period) * (float)col / (float)half_dim);
    const float arg = timestep * freq;

    embed_row[col] = cos(arg);
    embed_row[col + half_dim] = sin(arg);
}