#include "common.cuh"
#include "ggml.h"
#include "solve_tri.cuh"

#define MAX_N_FAST 64
#define MAX_K_FAST 32

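// SOLVE_TRI solves A X = B for X, where A is an n×n lower-triangular matrix
// and B is n×k, batched over the two outer tensor dimensions. Small problems
// (n <= MAX_N_FAST, k <= MAX_K_FAST) use the custom warp-based kernel below;
// larger ones fall back to batched cuBLAS trsm.
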
static __global__ void get_batch_pointers(const float *  A,
                                          float *        X,
                                          const float ** A_ptrs,
                                          float **       X_ptrs,
                                          int64_t        ne02,
                                          int64_t        total_batches,
                                          size_t         s02,
                                          size_t         s03,
                                          size_t         s2,
                                          size_t         s3) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= total_batches) {
        return;
    }

    const int64_t i3 = idx / ne02;
    const int64_t i2 = idx % ne02;

    A_ptrs[idx] = A + i3 * s03 + i2 * s02;
    X_ptrs[idx] = X + i3 * s3 + i2 * s2;
}

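// cuBLAS path: bulk-copy B into X once, then run a single batched in-place
// trsm over per-batch pointer tables built on the device by get_batch_pointers.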
static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx,
                                 const float *               A,
                                 const float *               B,
                                 float *                     X,
                                 int                         n,
                                 int                         k,
                                 int64_t                     ne02,
                                 int64_t                     ne03,
                                 size_t                      s02,
                                 size_t                      s03,
                                 size_t                      s12,
                                 size_t                      s13,
                                 size_t                      s2,
                                 size_t                      s3,
                                 cudaStream_t                stream) {
    const float   alpha         = 1.0f;
    const int64_t total_batches = ne02 * ne03;
    if (total_batches == 0) {
        return;
    }

    // Bulk copy B -> X (contiguous tensors)
    if (X != B) {
        const int64_t total_elements_BX = n * k * total_batches;
        CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream));
    }

    const int id = ggml_cuda_get_device();

    ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches);
    ggml_cuda_pool_alloc<float *>       X_ptrs_alloc(ctx.pool(id), total_batches);

    const float ** A_ptrs_dev = A_ptrs_alloc.get();
    float **       X_ptrs_dev = X_ptrs_alloc.get();

    get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02,
                                                                        total_batches, s02, s03, s2, s3);

    CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));

    // Yes, this is necessary; without it the TF32 tensor-op math mode set in
    // common.cuh causes RMSE failures in the solve.
    CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH));
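    // ggml tensors are row-major while cuBLAS is column-major: the row-major
    // lower-triangular A is seen by cuBLAS as the upper-triangular A^T, and
    // the row-major n×k matrices B/X are seen as k×n. Solving X^T A^T = B^T
    // from the right is therefore equivalent to solving A X = B directly,
    // without any explicit transposes or copies.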
    CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N,
                                    CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches));

    // revert to the standard math mode from common.cuh
    CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH));

    GGML_UNUSED_VARS(s12, s13);
}

// ======================
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
// ======================
// When n_template/k_template == 0 the bounds for the loops in this function
// are not known at compile time and can't be unrolled. As we want to keep
// #pragma unroll for all other cases we suppress the clang transformation
// warning here.
#ifdef __clang__
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
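// One thread block per batch matrix with blockDim = (WARP_SIZE, k): each
// y-slice of the block is a full warp that forward-substitutes one column of
// X. A is staged in shared memory; each lane owns up to two rows of the
// column (the "low" rows [0, 32) and the "high" rows [32, n)), and the dot
// products of the already solved entries are accumulated with warp_reduce_sum.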
template <int n_template, int k_template>
static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
                                          const float * __restrict__ B,
                                          float * __restrict__       X,
                                          const uint3                ne02,
                                          const size_t               nb02,
                                          const size_t               nb03,
                                          const size_t               nb12,
                                          const size_t               nb13,
                                          const size_t               nb2,
                                          const size_t               nb3,
                                          const int                  n_arg,
                                          const int                  k_arg) {
    const int n = n_template == 0 ? n_arg : n_template;
    const int k = k_template == 0 ? k_arg : k_template;

    const int batch_idx = blockIdx.x;
    const int lane      = threadIdx.x;
    const int col_idx   = threadIdx.y;

    if (col_idx >= k) {
        return;
    }

    const uint2   i02_i03 = fast_div_modulo(batch_idx, ne02);
    const int64_t i02     = i02_i03.y;
    const int64_t i03     = i02_i03.x;

    const float * const A_batch = A + i02 * nb02 + i03 * nb03;
    const float * const B_batch = B + i02 * nb12 + i03 * nb13;
    float *             X_batch = X + i02 * nb2 + i03 * nb3;

    __shared__ float sA[MAX_N_FAST * MAX_N_FAST];

    const int offset = threadIdx.x + threadIdx.y * blockDim.x;

    // stage the whole n×n matrix A in shared memory, all k*WARP_SIZE threads cooperating
#pragma unroll
    for (int i = 0; i < n * n; i += k * WARP_SIZE) {
        const int i0 = i + offset;
        if (i0 < n * n) {
            sA[i0] = A_batch[i0];
        }
    }

    __syncthreads();

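    // each lane keeps this column's row `lane` in x_low and row
    // `WARP_SIZE + lane` in x_high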
    float x_low  = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;

    const int half      = WARP_SIZE;
    const int nrows_low = (n < half) ? n : half;

    // forward substitution over rows [0, 32): one row per iteration, with a
    // warp reduction accumulating the dot product of the already solved entries
#pragma unroll
    for (int row = 0; row < nrows_low; ++row) {
        float sum = 0.0f;
        if (lane < row) {
            sum += sA[row * n + lane] * x_low;
        }
        sum = warp_reduce_sum(sum);

        if (lane == row) {
            x_low = (x_low - sum) / sA[row * n + row];
        }
    }

    // rows [32, n): the full low half contributes, plus the already solved
    // part of the high half
#pragma unroll
    for (int row = half; row < n; ++row) {
        float sum = sA[row * n + lane] * x_low;
        const int j = half + lane;
        if (j < row) {
            sum += sA[row * n + j] * x_high;
        }
        sum = warp_reduce_sum(sum);

        if (lane == row - half) {
            x_high = (x_high - sum) / sA[row * n + row];
        }
    }

    // write back both register halves of the solved column
#pragma unroll
    for (int rr = 0; rr < 2; ++rr) {
        const int row = rr * WARP_SIZE + lane;
        if (row < n) {
            const float val = (row < half) ? x_low : x_high;
            X_batch[row * k + col_idx] = val;
        }
    }
}
#ifdef __clang__
# pragma clang diagnostic pop
#endif // __clang__

static void solve_tri_f32_cuda(const float * A,
                               const float * B,
                               float *       X,
                               int           n,
                               int           k,
                               int64_t       ne02,
                               int64_t       ne03,
                               size_t        nb02,
                               size_t        nb03,
                               size_t        nb12,
                               size_t        nb13,
                               size_t        nb2,
                               size_t        nb3,
                               cudaStream_t  stream) {
    const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
    dim3        threads(WARP_SIZE, k);
    dim3        grid(ne02 * ne03);
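    // dispatch: for n == 64, common values of k get template instantiations
    // with compile-time bounds so the kernel's loops can be fully unrolled;
    // every other shape takes the runtime-bound <0, 0> variant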
    if (n == 64) {
        switch (k) {
            case 32:
                solve_tri_f32_fast<64, 32>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 16:
                solve_tri_f32_fast<64, 16>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 14:
                solve_tri_f32_fast<64, 14>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 12:
                solve_tri_f32_fast<64, 12>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 10:
                solve_tri_f32_fast<64, 10>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 8:
                solve_tri_f32_fast<64, 8>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 6:
                solve_tri_f32_fast<64, 6>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 4:
                solve_tri_f32_fast<64, 4>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 2:
                solve_tri_f32_fast<64, 2>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            case 1:
                solve_tri_f32_fast<64, 1>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
                break;
            default:
                solve_tri_f32_fast<0, 0>
                    <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
        }
    } else {  // run general case
        solve_tri_f32_fast<0, 0>
            <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
    }
}

void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];  // A (n×n, lower triangular)
    const ggml_tensor * src1 = dst->src[1];  // B (n×k)

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

    const int64_t n    = src0->ne[0];
    const int64_t k    = src1->ne[0];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
        solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
                           ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
                           src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
                           dst->nb[3] / sizeof(float), ctx.stream());
    } else {
        solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
                             ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
                             src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
                             dst->nb[3] / sizeof(float), ctx.stream());
    }
}