1#include "common.cuh"
  2#include "convert.cuh"
  3#include "tri.cuh"
  4#include "ggml.h"
  5
  6template<typename T, bool prefix_keep, int add_to_split>
  7static __global__ void tri_kernel(
  8        const T * src, T * dst,
  9        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
 10        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
 11        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3) {
 12    const int64_t i3 = blockIdx.z;
 13    const int64_t i2 = blockIdx.y;
 14    const int64_t i1 = blockIdx.x;
 15    const int64_t split_point = i1 + add_to_split;
 16
 17    GGML_UNUSED_VARS(nb00, nb0);
 18
 19    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
 20        return;
 21    }
 22
 23    const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03;
 24    T       * dst_row = dst + i1*nb1  + i2*nb2  + i3*nb3;
 25
 26    if constexpr (prefix_keep) {
 27        for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
 28            dst_row[i0] = src_row[i0];
 29        }
 30        for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
 31            dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);
 32        }
 33    } else {
 34        for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
 35            dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);
 36        }
 37        for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
 38            dst_row[i0] = src_row[i0];
 39        }
 40    }
 41}
 42
 43template<typename T>
 44static void tri_cuda(
 45        const T * src, T * dst,
 46        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
 47        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
 48        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3,
 49        const ggml_tri_type ttype,
 50        cudaStream_t stream) {
 51
 52    dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
 53    dim3 grid_dims(ne01, ne02, ne03);
 54    const size_t type_size = sizeof(T);
 55
 56    const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0;
 57    const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);
 58
 59    if (prefix_keep) {
 60        if (add_to_split == 0) {
 61            tri_kernel<T, true, 0><<<grid_dims, block_dims, 0, stream>>>(
 62                src, dst,
 63                ne00, ne01, ne02, ne03,
 64                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
 65                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
 66            );
 67        } else { // only 0 and 1 supported
 68            tri_kernel<T, true, 1><<<grid_dims, block_dims, 0, stream>>>(
 69                src, dst,
 70                ne00, ne01, ne02, ne03,
 71                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
 72                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
 73            );
 74        }
 75    } else {
 76        if (add_to_split == 0) {
 77            tri_kernel<T, false, 0><<<grid_dims, block_dims, 0, stream>>>(
 78                src, dst,
 79                ne00, ne01, ne02, ne03,
 80                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
 81                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
 82            );
 83        } else {
 84            tri_kernel<T, false, 1><<<grid_dims, block_dims, 0, stream>>>(
 85                src, dst,
 86                ne00, ne01, ne02, ne03,
 87                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
 88                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
 89            );
 90        }
 91    }
 92}
 93
 94void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 95    const ggml_tensor * src0 = dst->src[0];
 96    cudaStream_t stream = ctx.stream();
 97
 98    const ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0));
 99
100    GGML_ASSERT(src0->type == dst->type);
101
102    switch(src0->type) {
103        case GGML_TYPE_F32:
104            {
105                tri_cuda(
106                    (const float *)src0->data, (float *)dst->data,
107                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
108                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
109                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
110                    ttype, stream
111                );
112            } break;
113        case GGML_TYPE_F16:
114            {
115                tri_cuda(
116                    (const half *)src0->data, (half *)dst->data,
117                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
118                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
119                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
120                    ttype, stream
121                );
122            } break;
123        case GGML_TYPE_BF16:
124            {
125                tri_cuda(
126                    (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
127                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
128                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
129                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
130                    ttype, stream
131                );
132            } break;
133        default:
134            GGML_ABORT("fatal error");
135    }
136}