1#include "common.cuh"
2#include "convert.cuh"
3#include "tri.cuh"
4#include "ggml.h"
5
// Writes a triangularly-masked copy of src into dst.
//
// Launch layout: grid = (ne01, ne02, ne03) — one block per row — with blockDim.x
// threads striding over the ne00 elements of that row.
//
// Each row i1 is split at column (i1 + add_to_split). When prefix_keep is true the
// columns before the split are copied from src and the rest are zeroed; when false
// it is the other way around. Both prefix_keep and add_to_split are template
// parameters, so the per-element predicate below resolves at compile time.
//
// All nb* strides are in elements, not bytes (the host launcher divides by
// sizeof(T) before launching). nb00/nb0 are unused: rows are indexed directly,
// i.e. the innermost dimension is assumed dense.
template<typename T, bool prefix_keep, int add_to_split>
static __global__ void tri_kernel(
    const T * src, T * dst,
    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
    const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
    const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
    const int64_t row     = blockIdx.x; // i1
    const int64_t channel = blockIdx.y; // i2
    const int64_t sample  = blockIdx.z; // i3

    GGML_UNUSED_VARS(nb00, nb0);

    if (sample >= ne03 || channel >= ne02 || row >= ne01) {
        return;
    }

    // Column at which the kept region and the zeroed region swap.
    const int64_t split = row + add_to_split;

    const T * src_row = src + row*nb01 + channel*nb02 + sample*nb03;
    T       * dst_row = dst + row*nb1  + channel*nb2  + sample*nb3;

    for (int64_t col = threadIdx.x; col < ne00; col += blockDim.x) {
        // prefix_keep is a template parameter, so this selects one branch at compile time.
        const bool keep = prefix_keep ? (col < split) : (col >= split);
        dst_row[col] = keep ? src_row[col] : ggml_cuda_cast<T, float>(0.0f);
    }
}
42
// Host-side launcher for tri_kernel.
//
// Maps the op's triangle type onto the kernel's two compile-time knobs:
//   - prefix_keep: true for the LOWER variants (keep columns before the split),
//     false for the UPPER variants (keep columns from the split onwards);
//   - add_to_split: 1 when the diagonal belongs to the kept LOWER part
//     (LOWER_DIAG) or to the zeroed UPPER prefix (UPPER), 0 otherwise.
//
// nb* parameters are byte strides as stored in ggml tensors; they are converted
// to element strides exactly once here (instead of at every launch site) before
// being handed to the kernel.
template<typename T>
static void tri_cuda(
    const T * src, T * dst,
    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
    const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
    const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
    const ggml_tri_type ttype,
    cudaStream_t stream) {

    dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
    dim3 grid_dims(ne01, ne02, ne03);

    // Convert byte strides to element strides once. Use a signed divisor so the
    // int64_t strides are not promoted to unsigned by a size_t division.
    const int64_t ts = (int64_t) sizeof(T);
    const int64_t es00 = nb00 / ts, es01 = nb01 / ts, es02 = nb02 / ts, es03 = nb03 / ts;
    const int64_t es0  = nb0  / ts, es1  = nb1  / ts, es2  = nb2  / ts, es3  = nb3  / ts;

    const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0;
    const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);

    if (prefix_keep) {
        if (add_to_split == 0) {
            tri_kernel<T, true, 0><<<grid_dims, block_dims, 0, stream>>>(
                src, dst, ne00, ne01, ne02, ne03, es00, es01, es02, es03, es0, es1, es2, es3);
        } else { // only 0 and 1 supported
            tri_kernel<T, true, 1><<<grid_dims, block_dims, 0, stream>>>(
                src, dst, ne00, ne01, ne02, ne03, es00, es01, es02, es03, es0, es1, es2, es3);
        }
    } else {
        if (add_to_split == 0) {
            tri_kernel<T, false, 0><<<grid_dims, block_dims, 0, stream>>>(
                src, dst, ne00, ne01, ne02, ne03, es00, es01, es02, es03, es0, es1, es2, es3);
        } else {
            tri_kernel<T, false, 1><<<grid_dims, block_dims, 0, stream>>>(
                src, dst, ne00, ne01, ne02, ne03, es00, es01, es02, es03, es0, es1, es2, es3);
        }
    }
}
93
// Entry point for the TRI op: dispatches to the typed tri_cuda launcher based on
// the tensor's element type (F32, F16, BF16). The triangle variant is carried in
// the op's first i32 parameter. src and dst must share the same type.
void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src = dst->src[0];

    GGML_ASSERT(src->type == dst->type);

    const ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0));
    cudaStream_t stream = ctx.stream();

    switch (src->type) {
        case GGML_TYPE_F32:
            tri_cuda<float>(
                (const float *) src->data, (float *) dst->data,
                src->ne[0], src->ne[1], src->ne[2], src->ne[3],
                src->nb[0], src->nb[1], src->nb[2], src->nb[3],
                dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                ttype, stream);
            break;
        case GGML_TYPE_F16:
            tri_cuda<half>(
                (const half *) src->data, (half *) dst->data,
                src->ne[0], src->ne[1], src->ne[2], src->ne[3],
                src->nb[0], src->nb[1], src->nb[2], src->nb[3],
                dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                ttype, stream);
            break;
        case GGML_TYPE_BF16:
            tri_cuda<nv_bfloat16>(
                (const nv_bfloat16 *) src->data, (nv_bfloat16 *) dst->data,
                src->ne[0], src->ne[1], src->ne[2], src->ne[3],
                src->nb[0], src->nb[1], src->nb[2], src->nb[3],
                dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                ttype, stream);
            break;
        default:
            // Quantized and integer types are not supported by this op.
            GGML_ABORT("fatal error");
    }
}