#include "tsembd.cuh"

#include <cstdint>
 2
 3static __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) {
 4    // blockIDx.y: idx of timesteps->ne[0]
 5    // blockIDx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE
 6    int i = blockIdx.y;
 7    int j = threadIdx.x + blockIdx.x * blockDim.x;
 8    float * embed_data = (float *)((char *)dst +  i*nb1);
 9
10    int half = dim / 2;
11    if (dim % 2 != 0 && j == half) {
12        embed_data[2 * half] = 0.f;
13    }
14
15    if (j >= half) {
16        return;
17    }
18
19    float timestep = timesteps[i];
20    float freq = (float)expf(-logf(max_period) * j / half);
21    float arg = timestep * freq;
22    embed_data[j] = cosf(arg);
23    embed_data[j + half] = sinf(arg);
24}
25
// Launches timestep_embedding_f32 asynchronously on `stream`.
//
// x          : device pointer to ne00 float timesteps
// dst        : device pointer to the output; rows are nb1 bytes apart
// ne00       : number of timesteps (one output row each); mapped to grid y (CUDA caps gridDim.y at 65535)
// dim        : embedding dimension
// max_period : forwarded unchanged to the kernel
//
// Threads along x cover (dim + 1) / 2 indices rather than dim / 2, so that for odd dim
// one extra thread exists to zero the trailing column.
static void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1,
                                        const int dim, const int max_period, cudaStream_t stream) {
    const int half_ceil  = (dim + 1) / 2;
    const int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE;

    // Empty output: a grid with a zero dimension is an invalid launch configuration, so do nothing.
    if (num_blocks <= 0 || ne00 <= 0) {
        return;
    }

    // Named block_nums (not gridDim) to avoid shadowing the CUDA built-in variable.
    const dim3 block_nums(num_blocks, ne00, 1);
    timestep_embedding_f32<<<block_nums, CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE, 0, stream>>>(x, dst, nb1, dim, max_period);
}
33
// ggml CUDA entry point for the timestep-embedding op.
//
// Reads the timesteps from dst->src[0] and writes the embedding into dst.
// Both tensors must be F32 (asserted).
// op_params[0] holds the embedding dimension and op_params[1] the max period;
// both are forwarded to the launcher along with the row stride dst->nb[1].
void ggml_cuda_op_timestep_embedding(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * timesteps = dst->src[0];
    cudaStream_t        cu_stream = ctx.stream();

    GGML_ASSERT(timesteps->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type       == GGML_TYPE_F32);

    const int embed_dim  = dst->op_params[0];
    const int period_max = dst->op_params[1];

    timestep_embedding_f32_cuda(
        (const float *) timesteps->data,
        (float *) dst->data,
        timesteps->ne[0],
        dst->nb[1],
        embed_dim,
        period_max,
        cu_stream);
}