aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-cuda/set.cu
blob: 04bfe07ba033669f7cf2ce84abbf8a52fff48c78 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include "set.cuh"
#include "cpy.cuh"

// SET operation for the CUDA backend: src1 is written into a strided window of dst
// that begins `offset` bytes past dst->data.
//
// dst->op_params is read as five int32 values:
//   [0..2] nb1, nb2, nb3 - byte strides of the window along dims 1..3
//   [3]    offset        - byte offset of the window from dst->data
//   [4]    inplace       - when non-zero, the initial src0 -> dst transfer is skipped
//
// Preconditions (asserted): src0, src1 and dst share one type, which must be F32 or I32;
// all three tensors are contiguous; the window described by op_params lies inside dst.
//
// Both transfers are delegated to ggml_cuda_cpy(ctx, source, destination).
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32));
    GGML_ASSERT(src1->type == src0->type);
    GGML_ASSERT(dst ->type == src0->type);

    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

    const int32_t * params = (const int32_t *) dst->op_params;

    // dim-0 stride of the window is the element size, since dst is contiguous
    const size_t nb0     = ggml_element_size(dst);
    const size_t nb1     = params[0];
    const size_t nb2     = params[1];
    const size_t nb3     = params[2];
    const size_t offset  = params[3];
    const bool   inplace = params[4] != 0;

    // The window is described purely by op_params, so verify that its last element
    // still lies inside dst before anything is written; otherwise the copy below
    // would write past the end of dst's device allocation.
    // An empty src1 has no last element, so there is nothing to check in that case.
    if (ggml_nelements(src1) > 0) {
        const size_t last_elem = offset
            + (size_t) (src1->ne[0] - 1) * nb0
            + (size_t) (src1->ne[1] - 1) * nb1
            + (size_t) (src1->ne[2] - 1) * nb2
            + (size_t) (src1->ne[3] - 1) * nb3;
        GGML_ASSERT(last_elem + nb0 <= ggml_nbytes(dst));
    }

    if (!inplace) {
        // out-of-place: dst first receives src0, then the window is overwritten below
        ggml_cuda_cpy(ctx, src0, dst);
    }

    // Temporary tensor header aliasing dst's memory: same metadata as dst, but with
    // src1's extents, the strides from op_params and the data pointer moved by `offset`.
    ggml_tensor dst_view = *dst;
    dst_view.data  = (void *)((char *)dst->data + offset);
    dst_view.ne[0] = src1->ne[0];
    dst_view.ne[1] = src1->ne[1];
    dst_view.ne[2] = src1->ne[2];
    dst_view.ne[3] = src1->ne[3];

    dst_view.nb[0] = nb0;
    dst_view.nb[1] = nb1;
    dst_view.nb[2] = nb2;
    dst_view.nb[3] = nb3;

    ggml_cuda_cpy(ctx, src1, &dst_view);
}