// Sinusoidal timestep embedding.
//
// Work-item layout (see the get_global_id calls below):
//   global id 1 -> i, the timestep index and output row
//   global id 0 -> j, the frequency index within the row
//
// With half_dim = logical_dim / 2, every work-item with j < half_dim writes
//     row[j]            = cos(t[i] * freq)
//     row[j + half_dim] = sin(t[i] * freq)
// where freq = exp(-log(max_period) * j / half_dim), i.e. 1 at j == 0 and
// approaching 1 / max_period as j approaches half_dim.
// When logical_dim is odd, the work-item with j == half_dim zeroes the
// trailing element row[2 * half_dim] (== row[logical_dim - 1]).
//
// Parameters:
//   p_timesteps / off_timesteps  buffer and byte offset of the float timesteps
//   p_dst / off_dst              buffer and byte offset of the float output
//   dst_nb1_bytes                byte stride between consecutive output rows
//   logical_dim                  number of embedding elements written per row
//   max_period                   period used to derive the frequencies
//
// NOTE(review): the zero padding for odd logical_dim only happens if the host
// launches a work-item with j == half_dim in dimension 0 -- confirm against
// the enqueue code.
kernel void kernel_timestep_embedding(
    global const void * p_timesteps,
    ulong               off_timesteps,
    global void       * p_dst,
    ulong               off_dst,
    int                 dst_nb1_bytes,
    int                 logical_dim,
    int                 max_period
) {
    // Keep the const qualifier while applying the byte offset to the input.
    global const float * timesteps =
        (global const float *)((global const char *)p_timesteps + off_timesteps);

    const int i        = (int)get_global_id(1);
    const int j        = (int)get_global_id(0);
    const int half_dim = logical_dim / 2;

    // Compute the row offset in 64 bits: i * dst_nb1_bytes as a 32-bit int
    // product overflows once the output exceeds 2 GiB.
    const long row_offset = (long)i * (long)dst_nb1_bytes;
    global float * embed =
        (global float *)((global char *)p_dst + off_dst + row_offset);

    // Odd dimension: one extra work-item zeroes the element that has no
    // cos/sin pair.
    if ((logical_dim % 2 != 0) && (j == half_dim)) {
        embed[2 * half_dim] = 0.0f;
    }

    // Work-items beyond the number of frequencies have nothing left to do.
    // This also guarantees half_dim >= 1 below, so the division is safe.
    if (j >= half_dim) {
        return;
    }

    const float t    = timesteps[i];
    const float freq = exp(-log((float)max_period) * (float)j / (float)half_dim);
    const float arg  = t * freq;

    embed[j]            = cos(arg);
    embed[j + half_dim] = sin(arg);
}