1#include "convert.cuh"
2#include "diag.cuh"
3#include "ggml.h"
4
// Builds a batch of diagonal matrices from a batch of row vectors, one thread per dst element.
//
// Launch: 1D grid of 1D blocks providing at least `total_elements` threads in total; surplus
// threads exit immediately. No shared memory is used.
// Layout (both buffers contiguous, as asserted by the host launcher):
//   src: [ne0, 1,   ne2, ne3]
//   dst: [ne0, ne1, ne2, ne3]
//   dst[i0, i1, i2, i3] = (i0 == i1) ? src[i0, 0, i2, i3] : 0
// ne2/ne3 are accepted for interface stability but are not needed: the flattened batch index
// (i3 * ne2 + i2) is recovered directly from the flat element index.
template <typename T>
static __global__ void diag_kernel(T * __restrict__ dst,
                                   const T * __restrict__ src,
                                   const int64_t ne0,
                                   const int64_t ne1,
                                   const int64_t ne2,
                                   const int64_t ne3,
                                   const int64_t total_elements) {
    GGML_UNUSED_VARS(ne2, ne3);

    // Widen before multiplying: blockIdx.x * blockDim.x is 32-bit unsigned arithmetic and would
    // wrap around for tensors with 2^32 or more elements.
    const int64_t global_idx = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;

    if (global_idx >= total_elements) {
        return;
    }

    // dst is contiguous, so the flat thread index is also the flat dst offset.
    const int64_t row       = global_idx / ne0;      // flattened (i1, i2, i3)
    const int64_t i0        = global_idx - row * ne0;
    const int64_t i1        = row % ne1;
    const int64_t batch_idx = row / ne1;             // flattened (i2, i3) == i3 * ne2 + i2

    if (i0 == i1) {
        // src holds a single row per batch, so its flat offset is batch_idx * ne0 + i0.
        dst[global_idx] = src[batch_idx * ne0 + i0];
    } else {
        dst[global_idx] = ggml_cuda_cast<T>(0);
    }
}
35
// CUDA implementation of the diag op: dst = diag(src0).
//
// Each row vector of src0 (shape [ne0, 1, ne2, ne3]) is written onto the i0 == i1 diagonal of the
// corresponding [ne0, ne1] matrix in dst (shape [ne0, ne1, ne2, ne3]); all other entries are zero.
// Both tensors must be contiguous and share the same element type (F32 or F16); any other type
// aborts. The kernel is enqueued on the backend context's stream and is not synchronized here.
void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_is_contiguous(src0));
    // The kernel reinterprets both buffers as the same element type T, so the types must match.
    GGML_ASSERT(src0->type == dst->type);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    const int64_t ne0 = dst->ne[0];
    const int64_t ne1 = dst->ne[1];
    const int64_t ne2 = dst->ne[2];
    const int64_t ne3 = dst->ne[3];

    GGML_ASSERT(ne00 == ne0);
    GGML_ASSERT(ne01 == 1);
    GGML_ASSERT(ne02 == ne2);
    GGML_ASSERT(ne03 == ne3);

    const int64_t n_elems = ggml_nelements(dst);
    if (n_elems == 0) {
        // Nothing to write, and a zero-sized grid would be an invalid launch configuration.
        return;
    }

    const int64_t num_blocks = (n_elems + CUDA_DIAG_BLOCK_SIZE - 1) / CUDA_DIAG_BLOCK_SIZE;

    void *       dst_d  = dst->data;
    const void * src0_d = src0->data;

    cudaStream_t stream = ctx.stream();

    switch (dst->type) {
        case GGML_TYPE_F32:
            diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>((float *) dst_d, (const float *) src0_d, ne0,
                                                                         ne1, ne2, ne3, n_elems);
            break;
        case GGML_TYPE_F16:
            diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>((half *) dst_d, (const half *) src0_d, ne0,
                                                                         ne1, ne2, ne3, n_elems);
            break;
        default:
            GGML_ABORT("unsupported type");
    }
}