1#include "common.cuh"
  2#include "ggml.h"
  3#include "solve_tri.cuh"
  4
  5#define MAX_N_FAST 64
  6#define MAX_K_FAST 32
  7
  8static __global__ void get_batch_pointers(const float *  A,
  9                                          float *        X,
 10                                          const float ** A_ptrs,
 11                                          float **       X_ptrs,
 12                                          int64_t        ne02,
 13                                          int64_t        total_batches,
 14                                          size_t         s02,
 15                                          size_t         s03,
 16                                          size_t         s2,
 17                                          size_t         s3) {
 18    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
 19    if (idx >= total_batches) {
 20        return;
 21    }
 22
 23    const int64_t i3 = idx / ne02;
 24    const int64_t i2 = idx % ne02;
 25
 26    A_ptrs[idx] = A + i3 * s03 + i2 * s02;
 27    X_ptrs[idx] = X + i3 * s3 + i2 * s2;
 28}
 29
 30static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx,
 31                                 const float *               A,
 32                                 const float *               B,
 33                                 float *                     X,
 34                                 int                         n,
 35                                 int                         k,
 36                                 int64_t                     ne02,
 37                                 int64_t                     ne03,
 38                                 size_t                      s02,
 39                                 size_t                      s03,
 40                                 size_t                      s12,
 41                                 size_t                      s13,
 42                                 size_t                      s2,
 43                                 size_t                      s3,
 44                                 cudaStream_t                stream) {
 45    const float   alpha         = 1.0f;
 46    const int64_t total_batches = ne02 * ne03;
 47    if (total_batches == 0) {
 48        return;
 49    }
 50
 51    // Bulk copy B -> X (contiguous tensors)
 52    if (X != B) {
 53        const int64_t total_elements_BX = n * k * total_batches;
 54        CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream));
 55    }
 56
 57    const int id = ggml_cuda_get_device();
 58
 59    ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches);
 60    ggml_cuda_pool_alloc<float *>       X_ptrs_alloc(ctx.pool(id), total_batches);
 61
 62    const float ** A_ptrs_dev = A_ptrs_alloc.get();
 63    float **       X_ptrs_dev = X_ptrs_alloc.get();
 64
 65    get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02,
 66                                                                        total_batches, s02, s03, s2, s3);
 67
 68    CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
 69
 70    // Yes, this is necessary, without this we get RMSE errors
 71    CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH));
 72    CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N,
 73                                    CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches));
 74
 75    // revert to standard mode from common.cuh
 76    CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH));
 77
 78    GGML_UNUSED_VARS(s12, s13);
 79}
 80
 81// ======================
 82// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
 83// ======================
 84// When ncols_template == 0 the bounds for the loops in this function are not
 85// known and can't be unrolled. As we want to keep pragma unroll for all other
 86// cases we supress the clang transformation warning here.
 87#ifdef __clang__
 88#    pragma clang diagnostic push
 89#    pragma clang diagnostic ignored "-Wpass-failed"
 90#endif  // __clang__
// Forward-substitution solve of A·X = B for one (i2, i3) batch per block.
//
// Launch layout (see solve_tri_f32_cuda): blockDim = (WARP_SIZE, k), so each
// warp (fixed threadIdx.y) owns one column of B/X; gridDim.x = ne02 * ne03.
// A is n×n lower-triangular, B/X are n×k, row-major; the nb* strides are in
// elements. Template parameters n_template/k_template pin n and k at compile
// time for loop unrolling; 0 selects the runtime n_arg/k_arg instead.
// Requires n <= MAX_N_FAST (64) and k <= MAX_K_FAST (32): each lane holds at
// most two rows of its column.
template <int n_template, int k_template>
static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
                                          const float * __restrict__ B,
                                          float * __restrict__ X,
                                          const uint3  ne02,
                                          const size_t nb02,
                                          const size_t nb03,
                                          const size_t nb12,
                                          const size_t nb13,
                                          const size_t nb2,
                                          const size_t nb3,
                                          const int    n_arg,
                                          const int    k_arg) {
    const int n = n_template == 0 ? n_arg : n_template;
    const int k = k_template == 0 ? k_arg : k_template;

    const int batch_idx = blockIdx.x;
    const int lane      = threadIdx.x;  // lane within the warp, also a row index
    const int col_idx   = threadIdx.y;  // column of B/X this warp solves

    // Dead in practice: the launcher always uses blockDim.y == k. If it ever
    // fired, only part of the block would reach the __syncthreads() below.
    if (col_idx >= k) {
        return;
    }

    // Split the flat batch index into (i2, i3); ne02 is a fastdiv-packed divisor.
    const uint2   i02_i03 = fast_div_modulo(batch_idx, ne02);
    const int64_t i02     = i02_i03.y;
    const int64_t i03     = i02_i03.x;

    // nb* strides are element counts, so plain pointer arithmetic applies.
    const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
    const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
    float *             X_batch = (float *) (X + i02 * nb2 + i03 * nb3);

    // Whole n×n tile of A staged in shared memory; every warp reads it.
    __shared__ float sA[MAX_N_FAST * MAX_N_FAST];

    // Flat thread id within the block, used for the cooperative load.
    const int offset = threadIdx.x + threadIdx.y * blockDim.x;

    // All k*WARP_SIZE threads stride through the n*n elements of A.
#pragma unroll
    for (int i = 0; i < n * n; i += k * WARP_SIZE) {
        const int i0 = i + offset;
        if (i0 < n * n) {
            sA[i0] = A_batch[i0];
        }
    }

    __syncthreads();

    // Each lane caches up to two entries of its column: row `lane` (x_low)
    // and row `lane + 32` (x_high). B/X rows have stride k (row-major n×k).
    float x_low  = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;

    const int half      = WARP_SIZE;
    const int nrows_low = (n < half) ? n : half;

    // Phase 1: rows [0, min(n, 32)). For each row, lanes < row contribute
    // A[row][lane] * x[lane] (already-solved entries); the warp reduction
    // forms the dot product and lane == row finishes the substitution.
#pragma unroll
    for (int row = 0; row < nrows_low; ++row) {
        float sum = 0.0f;
        if (lane < row) {
            sum += sA[row * n + lane] * x_low;
        }
        sum = warp_reduce_sum(sum);

        if (lane == row) {
            x_low = (x_low - sum) / sA[row * n + row];
        }
    }

    // Phase 2: rows [32, n). All 32 low entries are solved by now, so every
    // lane contributes its x_low term unconditionally; high entries x[32+lane]
    // join only once their row index j is below the row being solved.
#pragma unroll
    for (int row = half; row < n; ++row) {
        float     sum = sA[row * n + lane] * x_low;
        const int j   = half + lane;
        if (j < row) {
            sum += sA[row * n + j] * x_high;
        }
        sum = warp_reduce_sum(sum);

        if (lane == row - half) {
            x_high = (x_high - sum) / sA[row * n + row];
        }
    }

    // Write the solved column back to X (two rows per lane at most).
#pragma unroll
    for (int rr = 0; rr < 2; ++rr) {
        const int row = rr * WARP_SIZE + lane;
        if (row < n) {
            const float val            = (row < half) ? x_low : x_high;
            X_batch[row * k + col_idx] = val;
        }
    }
}
179#ifdef __clang__
180#    pragma clang diagnostic pop
181#endif  // __clang__
182
// Dispatches the fast triangular-solve kernel: one block per (i2, i3) batch,
// one warp per column of X. For the hot shape n == 64 the common even values
// of k get fully templated instantiations (compile-time loop bounds); every
// other shape falls through to the runtime-sized <0, 0> variant.
static void solve_tri_f32_cuda(const float * A,
                               const float * B,
                               float *       X,
                               int           n,
                               int           k,
                               int64_t       ne02,
                               int64_t       ne03,
                               size_t        nb02,
                               size_t        nb03,
                               size_t        nb12,
                               size_t        nb13,
                               size_t        nb2,
                               size_t        nb3,
                               cudaStream_t  stream) {
    const uint3 ne02_packed = init_fastdiv_values((uint32_t) ne02);
    const dim3  block_dims(WARP_SIZE, k);
    const dim3  grid_dims(ne02 * ne03);

    // Single launch site: NT/KT are the template arguments, NA/KA the runtime
    // fallbacks (only read when the template arguments are 0).
#define SOLVE_TRI_LAUNCH(NT, KT, NA, KA)                                               \
    solve_tri_f32_fast<NT, KT><<<grid_dims, block_dims, 0, stream>>>(                  \
        A, B, X, ne02_packed, nb02, nb03, nb12, nb13, nb2, nb3, NA, KA)

    if (n != 64) {  // general case
        SOLVE_TRI_LAUNCH(0, 0, n, k);
    } else {
        switch (k) {
            case 32: SOLVE_TRI_LAUNCH(64, 32, 0, 0); break;
            case 16: SOLVE_TRI_LAUNCH(64, 16, 0, 0); break;
            case 14: SOLVE_TRI_LAUNCH(64, 14, 0, 0); break;
            case 12: SOLVE_TRI_LAUNCH(64, 12, 0, 0); break;
            case 10: SOLVE_TRI_LAUNCH(64, 10, 0, 0); break;
            case 8:  SOLVE_TRI_LAUNCH(64, 8, 0, 0);  break;
            case 6:  SOLVE_TRI_LAUNCH(64, 6, 0, 0);  break;
            case 4:  SOLVE_TRI_LAUNCH(64, 4, 0, 0);  break;
            case 2:  SOLVE_TRI_LAUNCH(64, 2, 0, 0);  break;
            case 1:  SOLVE_TRI_LAUNCH(64, 1, 0, 0);  break;
            default: SOLVE_TRI_LAUNCH(0, 0, n, k);   break;
        }
    }

#undef SOLVE_TRI_LAUNCH
}
251
// CUDA entry point for the SOLVE_TRI op: dst = A^{-1} · B with A lower
// triangular. Small systems go to the custom warp kernel, everything else to
// the cuBLAS batched trsm fallback. Byte strides nb[2]/nb[3] are converted to
// element strides for both paths.
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];  // A (n×n, lower triangular)
    const ggml_tensor * src1 = dst->src[1];  // B (n×k)

    // Both paths assume densely packed tensors. The previous code called
    // ggml_is_contiguous() and discarded the result, making the checks no-ops.
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

    const int64_t n    = src0->ne[0];
    const int64_t k    = src1->ne[0];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
        solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
                           ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
                           src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
                           dst->nb[3] / sizeof(float), ctx.stream());
    } else {
        solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
                             ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
                             src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
                             dst->nb[3] / sizeof(float), ctx.stream());
    }
}