1#include "norm.cuh"
2#include <cstdint>
3
// Layer normalization (x - mean) / sqrt(var + eps) of one row per block, f32.
// Launch layout: gridDim = (nrows, nchannels, nsamples), blockDim = (block_size, 1, 1).
// x is addressed through element strides (stride_row / stride_channel / stride_sample);
// dst is written densely as [nsamples][nchannels][nrows][ncols].
// Dynamic shared memory is the float2 scratch buffer handed to block_reduce
// (norm_f32_cuda passes 0 for WARP_SIZE blocks and 32*sizeof(float2) otherwise).
template <int block_size>
static __global__ void norm_f32(
        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const int row     = blockIdx.x;
    const int channel = blockIdx.y;
    const int sample  = blockIdx.z;
    const int tid     = threadIdx.x;

    // 64-bit arithmetic: the dense dst offset can exceed INT_MAX for large tensors.
    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
    dst += ((int64_t(sample)*nchannels + channel)*nrows + row)*ncols;

    // Single pass: .x accumulates sum(x), .y accumulates sum(x^2).
    float2 mean_var = make_float2(0.0f, 0.0f);

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        mean_var.x += xi;
        mean_var.y += xi * xi;
    }

    // sum up partial sums across the block
    extern __shared__ float2 s_sum2[];
    mean_var = block_reduce<block_reduce_method::SUM, block_size>(mean_var, s_sum2);

    const float mean = mean_var.x / ncols;
    // E[x^2] - E[x]^2 can round to a small negative number when |mean| is large compared to the
    // spread of the row; clamp so rsqrtf never receives a negative argument (which yields NaN).
    const float var     = fmaxf(mean_var.y / ncols - mean * mean, 0.0f);
    const float inv_std = rsqrtf(var + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = (x[col] - mean) * inv_std;
    }
}
39
// Group normalization, f32. Block b normalizes the flat span
// [b*group_size, min((b+1)*group_size, ne_elements)) of x into the same span of dst.
// Passes: (1) sum -> mean, (2) write centered values to dst and sum their squares,
// (3) rescale dst by rsqrt(variance + eps).
// Launch layout: gridDim.x = number of groups, blockDim.x = block_size.
// Dynamic shared memory is the float scratch buffer handed to block_reduce.
template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
    // 64-bit bounds so blockIdx.x*group_size cannot wrap around.
    const int64_t group_start = int64_t(blockIdx.x)*group_size;
    const int64_t group_end   = group_start + group_size < ne_elements ? group_start + group_size : int64_t(ne_elements);

    // group_start/group_end depend only on blockIdx, so every thread of the block takes this
    // branch together; returning before block_reduce cannot strand a barrier.
    if (group_end <= group_start) {
        return;
    }

    // Normalize by the number of elements actually present: the last group may be truncated
    // at ne_elements, in which case it holds fewer than group_size elements.
    const float n_elems = (float) (group_end - group_start);

    const int64_t start = group_start + threadIdx.x;

    float tmp = 0.0f; // partial sum for this thread

    for (int64_t j = start; j < group_end; j += block_size) {
        tmp += x[j];
    }

    extern __shared__ float s_sum[];
    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);

    const float mean = tmp / n_elems;
    tmp = 0.0f;

    // Two-pass variance: stage the centered values in dst and accumulate their squares.
    for (int64_t j = start; j < group_end; j += block_size) {
        const float xi = x[j] - mean;
        dst[j] = xi;
        tmp += xi * xi;
    }

    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);

    const float variance = tmp / n_elems;
    const float scale    = rsqrtf(variance + eps);
    for (int64_t j = start; j < group_end; j += block_size) {
        dst[j] *= scale;
    }
}
73
// RMS normalization of one row per block (f32), with an optional fused epilogue:
//   !do_multiply            : dst = x * rsqrt(mean(x^2) + eps)
//   do_multiply && !do_add  : dst = x * rsqrt(mean(x^2) + eps) * mul
//   do_multiply && do_add   : dst = x * rsqrt(mean(x^2) + eps) * mul + add
// Launch layout: gridDim = (nrows, nchannels, nsamples), blockDim = (block_size, 1, 1).
// x, mul and add are addressed through element strides; dst is written densely as
// [nsamples][nchannels][nrows][ncols]. mul/add are broadcast: each index is reduced modulo the
// matching extent with fastmodulo and the precomputed *_packed divisors.
// Dynamic shared memory is the float scratch buffer handed to block_reduce.
template <int block_size, bool do_multiply = false, bool do_add = false>
static __global__ void rms_norm_f32(const float * x,
                                    float *       dst,
                                    const int     ncols,
                                    const int64_t stride_row,
                                    const int64_t stride_channel,
                                    const int64_t stride_sample,
                                    const float   eps,
                                    const float * mul                  = nullptr,
                                    const int64_t mul_stride_row       = 0,
                                    const int64_t mul_stride_channel   = 0,
                                    const int64_t mul_stride_sample    = 0,
                                    const uint3   mul_ncols_packed     = make_uint3(0, 0, 0),
                                    const uint3   mul_nrows_packed     = make_uint3(0, 0, 0),
                                    const uint3   mul_nchannels_packed = make_uint3(0, 0, 0),
                                    const uint3   mul_nsamples_packed  = make_uint3(0, 0, 0),
                                    const float * add                  = nullptr,
                                    const int64_t add_stride_row       = 0,
                                    const int64_t add_stride_channel   = 0,
                                    const int64_t add_stride_sample    = 0,
                                    const uint3   add_ncols_packed     = make_uint3(0, 0, 0),
                                    const uint3   add_nrows_packed     = make_uint3(0, 0, 0),
                                    const uint3   add_nchannels_packed = make_uint3(0, 0, 0),
                                    const uint3   add_nsamples_packed  = make_uint3(0, 0, 0)) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const int row     = blockIdx.x;
    const int channel = blockIdx.y;
    const int sample  = blockIdx.z;
    const int tid     = threadIdx.x;

    static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");

    // 64-bit arithmetic: the dense dst offset can exceed INT_MAX for large tensors.
    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
    dst += ((int64_t(sample)*nchannels + channel)*nrows + row)*ncols;

    if constexpr (do_multiply) {
        // broadcast mul over rows/channels/samples
        const uint32_t mul_row     = fastmodulo(row, mul_nrows_packed);
        const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
        const uint32_t mul_sample  = fastmodulo(sample, mul_nsamples_packed);
        mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
    }

    if constexpr (do_add) {
        // broadcast add over rows/channels/samples (same index type as the mul path)
        const uint32_t add_row     = fastmodulo(row, add_nrows_packed);
        const uint32_t add_channel = fastmodulo(channel, add_nchannels_packed);
        const uint32_t add_sample  = fastmodulo(sample, add_nsamples_packed);
        add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
    }

    float tmp = 0.0f; // partial sum of squares for this thread

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        tmp += xi * xi;
    }

    // sum up partial sums across the block
    extern __shared__ float s_sum[];
    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);

    const float mean  = tmp / ncols;
    const float scale = rsqrtf(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
        if constexpr (do_multiply && do_add) {
            const int mul_col = fastmodulo(col, mul_ncols_packed);
            const int add_col = fastmodulo(col, add_ncols_packed);
            dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
        } else if constexpr (do_multiply) {
            const int mul_col = fastmodulo(col, mul_ncols_packed);
            dst[col] = scale * x[col] * mul[mul_col];
        } else {
            dst[col] = scale * x[col];
        }
    }
}
152
// Backward pass of RMS normalization for one row, f32.
// With s = sum(x^2), d = sum(x*g), n = ncols and r = rsqrt(s/n + eps):
//   dst = r*g - r * d / (s + n*eps) * x
// Launch layout: gridDim.x = nrows, blockDim = (block_size, 1, 1); row `row` starts at
// row*ncols in grad, xf and dst. Only static shared memory is used (block_size == 1024 path).
template <int block_size>
static __global__ void rms_norm_back_f32(
        const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
    const int row      = blockIdx.x*blockDim.y + threadIdx.y;
    const int first_col = threadIdx.x;

    const int64_t row_offset = int64_t(row)*ncols;
    const float * g_row = grad + row_offset;
    const float * x_row = xf   + row_offset;
    float       * d_row = dst  + row_offset;

    float sq_sum  = 0.0f; // sum of x*x, same quantity as the forward pass
    float dot_sum = 0.0f; // sum of x*grad, the cross term introduced by the normalization

    for (int col = first_col; col < ncols; col += block_size) {
        const float xv = x_row[col];
        sq_sum  += xv * xv;
        dot_sum += xv * g_row[col];
    }

    // reduce within each warp first
    sq_sum  = warp_reduce_sum(sq_sum);
    dot_sum = warp_reduce_sum(dot_sum);

    if constexpr (block_size > WARP_SIZE) {
        static_assert(block_size == 1024, "unexpected block_size");
        // one slot per warp; lane 0 of each warp publishes, then every warp re-reduces all slots
        __shared__ float s_sum_xx[32];
        __shared__ float s_sum_xg[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum_xx[warp_id] = sq_sum;
            s_sum_xg[warp_id] = dot_sum;
        }
        __syncthreads();

        sq_sum  = warp_reduce_sum(s_sum_xx[lane_id]);
        dot_sum = warp_reduce_sum(s_sum_xg[lane_id]);
    }

    const float scale_grad = rsqrtf(sq_sum / ncols + eps);
    const float scale_x    = -scale_grad * dot_sum / (sq_sum + ncols*eps);

    for (int col = first_col; col < ncols; col += block_size) {
        d_row[col] = scale_grad*g_row[col] + scale_x*x_row[col];
    }
}
204
205// template <int block_size>
206// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
207// const int row = blockIdx.x*blockDim.y + threadIdx.y;
208// const int tid = threadIdx.x;
209
210// float tmp = 0.0f; // partial sum for thread in warp
211
212// for (int col = tid; col < ncols; col += block_size) {
213// const float xi = x[row*ncols + col];
214// tmp += xi * xi;
215// }
216
217// // sum up partial sums
218// tmp = warp_reduce_sum(tmp);
219// if (block_size > WARP_SIZE) {
220// __shared__ float s_sum[32];
221// int warp_id = threadIdx.x / WARP_SIZE;
222// int lane_id = threadIdx.x % WARP_SIZE;
223// if (lane_id == 0) {
224// s_sum[warp_id] = tmp;
225// }
226// __syncthreads();
227// tmp = s_sum[lane_id];
228// tmp = warp_reduce_sum(tmp);
229// }
230
231// // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
232// const float scale = rsqrtf(fmaxf(tmp, eps * eps));
233
234// for (int col = tid; col < ncols; col += block_size) {
235// dst[row*ncols + col] = scale * x[row*ncols + col];
236// }
237// }
238
// L2 normalization of one row per block (f32): dst = x * rsqrt(max(sum(x^2), eps^2)),
// i.e. x / max(||x||_2, eps) for eps >= 0, following
// https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
// Launch layout: gridDim = (nrows, nchannels, nsamples), blockDim = (block_size, 1, 1).
// x is addressed through element strides; dst is written densely as
// [nsamples][nchannels][nrows][ncols].
// Dynamic shared memory is the float scratch buffer handed to block_reduce.
template <int block_size>
static __global__ void l2_norm_f32(
        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const int row     = blockIdx.x;
    const int channel = blockIdx.y;
    const int sample  = blockIdx.z;
    const int tid     = threadIdx.x;

    // 64-bit arithmetic: the dense dst offset can exceed INT_MAX for large tensors.
    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
    dst += ((int64_t(sample)*nchannels + channel)*nrows + row)*ncols;

    float tmp = 0.0f; // partial sum of squares for this thread

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        tmp += xi * xi;
    }

    // sum up partial sums across the block
    extern __shared__ float s_sum[];
    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);

    const float scale = rsqrtf(fmaxf(tmp, eps * eps));

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = scale * x[col];
    }
}
272
// Host launcher for norm_f32: one block per (row, channel, sample).
// Rows with at least 1024 columns get 1024 threads plus 32 float2 of dynamic shared memory
// for the block-wide reduction; shorter rows get a single warp and no dynamic shared memory.
static void norm_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 grid(nrows, nchannels, nsamples);

    if (ncols >= 1024) {
        constexpr int wide_block = 1024;
        const size_t smem_bytes = wide_block > WARP_SIZE ? 32 * sizeof(float2) : 0;
        norm_f32<wide_block><<<grid, dim3(wide_block, 1, 1), smem_bytes, stream>>>(
            x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
        return;
    }

    norm_f32<WARP_SIZE><<<grid, dim3(WARP_SIZE, 1, 1), 0, stream>>>(
        x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
285
// Host launcher for group_norm_f32: one block per group (num_groups blocks).
// Groups of at least 1024 elements use 1024 threads and 32 floats of dynamic shared memory;
// smaller groups use a single warp and no dynamic shared memory.
static void group_norm_f32_cuda(
        const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
    if (group_size >= 1024) {
        constexpr int wide_block = 1024;
        const size_t smem_bytes = wide_block > WARP_SIZE ? 32 * sizeof(float) : 0;
        group_norm_f32<wide_block><<<num_groups, dim3(wide_block, 1, 1), smem_bytes, stream>>>(
            x, dst, group_size, ne_elements, eps);
    } else {
        group_norm_f32<WARP_SIZE><<<num_groups, dim3(WARP_SIZE, 1, 1), 0, stream>>>(
            x, dst, group_size, ne_elements, eps);
    }
}
296
// Host launcher for the unfused rms_norm_f32: one block per (row, channel, sample),
// 1024 threads for rows of at least 1024 columns and 256 threads otherwise. Dynamic shared
// memory (32 floats) is requested whenever the block is wider than one warp.
static void rms_norm_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 grid(nrows, nchannels, nsamples);
    const bool wide_rows = ncols >= 1024;
    const dim3 block(wide_rows ? 1024 : 256, 1, 1);
    const size_t smem_bytes = block.x > WARP_SIZE ? 32 * sizeof(float) : 0;

    if (wide_rows) {
        rms_norm_f32<1024, false><<<grid, block, smem_bytes, stream>>>(
            x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    } else {
        rms_norm_f32<256, false><<<grid, block, smem_bytes, stream>>>(
            x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    }
}
309
// Host launcher for rms_norm_f32 with an optional fused multiply (mul != nullptr) and optional
// fused add (add != nullptr, only honoured together with mul). With mul == nullptr this falls
// back to the plain rms_norm_f32_cuda launcher. Broadcast extents are converted to fastdiv
// constants on the host before launch.
static void rms_norm_mul_f32_cuda(const float *  x,
                                  const float *  mul,
                                  const float *  add,
                                  float *        dst,
                                  const int      ncols,
                                  const int      nrows,
                                  const int      nchannels,
                                  const int      nsamples,
                                  const int64_t  stride_row,
                                  const int64_t  stride_channel,
                                  const int64_t  stride_sample,
                                  const int64_t  mul_stride_row,
                                  const int64_t  mul_stride_channel,
                                  const int64_t  mul_stride_sample,
                                  const uint32_t mul_ncols,
                                  const uint32_t mul_nrows,
                                  const uint32_t mul_nchannels,
                                  const uint32_t mul_nsamples,
                                  const int64_t  add_stride_row,
                                  const int64_t  add_stride_channel,
                                  const int64_t  add_stride_sample,
                                  const uint32_t add_ncols,
                                  const uint32_t add_nrows,
                                  const uint32_t add_nchannels,
                                  const uint32_t add_nsamples,
                                  const float    eps,
                                  cudaStream_t   stream) {
    if (mul == nullptr) {
        // nothing to fuse
        rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
        return;
    }

    const dim3 grid(nrows, nchannels, nsamples);
    const bool wide_rows = ncols >= 1024;
    const dim3 block(wide_rows ? 1024 : 256, 1, 1);
    const size_t smem_bytes = block.x > WARP_SIZE ? 32 * sizeof(float) : 0;

    const uint3 mul_ncols_packed     = init_fastdiv_values(mul_ncols);
    const uint3 mul_nrows_packed     = init_fastdiv_values(mul_nrows);
    const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
    const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);

    if (add == nullptr) {
        // fused RMS_NORM * mul
        if (wide_rows) {
            rms_norm_f32<1024, true><<<grid, block, smem_bytes, stream>>>(
                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
        } else {
            rms_norm_f32<256, true><<<grid, block, smem_bytes, stream>>>(
                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
        }
        return;
    }

    // fused RMS_NORM * mul + add
    const uint3 add_ncols_packed     = init_fastdiv_values(add_ncols);
    const uint3 add_nrows_packed     = init_fastdiv_values(add_nrows);
    const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels);
    const uint3 add_nsamples_packed  = init_fastdiv_values(add_nsamples);

    if (wide_rows) {
        rms_norm_f32<1024, true, true><<<grid, block, smem_bytes, stream>>>(
            x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
            mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
            add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
            add_nchannels_packed, add_nsamples_packed);
    } else {
        rms_norm_f32<256, true, true><<<grid, block, smem_bytes, stream>>>(
            x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
            mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
            add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
            add_nchannels_packed, add_nsamples_packed);
    }
}
385
// Host launcher for rms_norm_back_f32: one block per row; 1024 threads for rows of at least
// 1024 columns, a single warp otherwise. The kernel relies on static shared memory only.
static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    if (ncols >= 1024) {
        rms_norm_back_f32<1024><<<nrows, dim3(1024, 1, 1), 0, stream>>>(grad, xf, dst, ncols, eps);
        return;
    }
    rms_norm_back_f32<WARP_SIZE><<<nrows, dim3(WARP_SIZE, 1, 1), 0, stream>>>(grad, xf, dst, ncols, eps);
}
395
// Host launcher for l2_norm_f32: one block per (row, channel, sample).
// Rows of at least 1024 columns use 1024 threads and 32 floats of dynamic shared memory;
// shorter rows use a single warp and no dynamic shared memory.
static void l2_norm_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 grid(nrows, nchannels, nsamples);

    if (ncols < 1024) {
        l2_norm_f32<WARP_SIZE><<<grid, dim3(WARP_SIZE, 1, 1), 0, stream>>>(
            x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
        return;
    }

    constexpr int wide_block = 1024;
    const size_t smem_bytes = wide_block > WARP_SIZE ? 32 * sizeof(float) : 0;
    l2_norm_f32<wide_block><<<grid, dim3(wide_block, 1, 1), smem_bytes, stream>>>(
        x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
408
// GGML_OP_NORM on CUDA: f32 layer normalization along dim 0 of dst->src[0], written to dst.
// op_params[0] holds eps (as float bits). Dim 0 must be contiguous; dims 1..3 may be strided.
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_TENSOR_UNARY_OP_LOCALS;

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    // byte strides -> element strides
    const size_t elem_size = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == elem_size);

    const int64_t row_stride     = nb01 / elem_size;
    const int64_t channel_stride = nb02 / elem_size;
    const int64_t sample_stride  = nb03 / elem_size;

    norm_f32_cuda((const float *) src0->data, (float *) dst->data, ne00, ne01, ne02, ne03,
                  row_stride, channel_stride, sample_stride, eps, stream);
}
432
// GGML_OP_GROUP_NORM on CUDA (f32). op_params[0] = number of groups over dim 2,
// op_params[1] = eps (as float bits). Each group spans ne0*ne1*ceil(ne2/num_groups) elements
// of the flattened tensor; one kernel block is launched per (group, dim-3 index).
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
    // group_norm_f32 walks each group as one flat span of memory in both src0 and dst.
    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(dst));

    const int num_groups = dst->op_params[0];
    GGML_ASSERT(num_groups > 0); // used as a divisor below

    float eps;
    memcpy(&eps, dst->op_params + 1, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    // Compute sizes in 64 bits, then narrow explicitly: the launcher/kernel take int sizes.
    const int64_t channels_per_group = (src0->ne[2] + num_groups - 1) / num_groups;
    const int64_t group_size         = src0->ne[0] * src0->ne[1] * channels_per_group;
    const int64_t n_blocks           = (int64_t) num_groups * src0->ne[3];
    const int64_t ne_elements        = ggml_nelements(src0);
    GGML_ASSERT(ne_elements <= INT32_MAX);
    GGML_ASSERT(n_blocks    <= INT32_MAX);

    group_norm_f32_cuda(src0_d, dst_d, (int) n_blocks, eps, (int) group_size, (int) ne_elements, stream);
}
451
// GGML_OP_RMS_NORM on CUDA: f32 RMS normalization along dim 0 of dst->src[0], written to dst.
// op_params[0] holds eps (as float bits). Dim 0 must be contiguous; dims 1..3 may be strided.
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * in_d  = (const float *) src0->data;
    float       * out_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_TENSOR_UNARY_OP_LOCALS;

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    const size_t elem_size = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == elem_size);

    // byte strides -> element strides for dims 1..3
    rms_norm_f32_cuda(in_d, out_d, ne00, ne01, ne02, ne03,
                      (int64_t) (nb01 / elem_size), (int64_t) (nb02 / elem_size), (int64_t) (nb03 / elem_size),
                      eps, stream);
}
475
// Fused RMS_NORM + MUL: computes mul_tensor = rms_norm(dst->src[0]) * w in a single kernel,
// where w is the operand of mul_tensor that is not the RMS_NORM node `dst`. The result is
// written to mul_tensor->data; this function does not write the RMS_NORM node's own buffer.
void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
    float eps = 0.0f;

    memcpy(&eps, dst->op_params, sizeof(float));

    const float * src0_d = (const float *) rms_norm_src->data;
    const ggml_tensor * mul_src = nullptr;

    if (mul_tensor->src[0] == dst) {
        mul_src = mul_tensor->src[1];
    } else if (mul_tensor->src[1] == dst) {
        mul_src = mul_tensor->src[0];
    } else {
        GGML_ASSERT(false);
    }
    // If both MUL operands were the RMS_NORM node, the weight would be read from the RMS_NORM
    // buffer, which this fused path never fills.
    GGML_ASSERT(mul_src != dst);

    const float * mul_d = (const float *) mul_src->data;
    float * dst_d = (float *) mul_tensor->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
    // The kernel reads the weight as float; its strides below are derived from its own type size.
    GGML_ASSERT(mul_src->type == GGML_TYPE_F32);
    GGML_ASSERT(eps >= 0.0f);

    const int64_t ne00 = rms_norm_src->ne[0];
    const int64_t ne01 = rms_norm_src->ne[1];
    const int64_t ne02 = rms_norm_src->ne[2];
    const int64_t ne03 = rms_norm_src->ne[3];

    const size_t ts0 = ggml_type_size(rms_norm_src->type);
    GGML_ASSERT(rms_norm_src->nb[0] == ts0);
    const int64_t s01 = rms_norm_src->nb[1] / ts0;
    const int64_t s02 = rms_norm_src->nb[2] / ts0;
    const int64_t s03 = rms_norm_src->nb[3] / ts0;

    const size_t ts_mul = ggml_type_size(mul_src->type);
    GGML_ASSERT(mul_src->nb[0] == ts_mul);
    const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;

    const int mul_ncols     = mul_src->ne[0];
    const int mul_nrows     = mul_src->ne[1];
    const int mul_nchannels = mul_src->ne[2];
    const int mul_nsamples  = mul_src->ne[3];

    rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,
                          ne00, ne01, ne02, ne03,
                          /*s00*/ s01, s02, s03,
                          /*mul_s00*/ mul_s01, mul_s02, mul_s03,
                          mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
                          /*add_s00*/ 0, 0, 0,
                          0, 0, 0, 0,
                          eps, stream);
}
535
// Fused RMS_NORM + MUL + ADD: computes add_tensor = rms_norm(dst->src[0]) * w + b in a single
// kernel. w is the operand of mul_tensor that is not `dst`, b is the operand of add_tensor that
// is not mul_tensor. The result is written to add_tensor->data; this function writes neither
// the RMS_NORM nor the MUL node's own buffer.
void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
                                     ggml_tensor *               dst,
                                     ggml_tensor *               mul_tensor,
                                     ggml_tensor *               add_tensor) {
    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
    float eps = 0.0f;

    memcpy(&eps, dst->op_params, sizeof(float));

    const float * src0_d = (const float *) rms_norm_src->data;

    const ggml_tensor * mul_src = nullptr;
    if (mul_tensor->src[0] == dst) {
        mul_src = mul_tensor->src[1];
    } else if (mul_tensor->src[1] == dst) {
        mul_src = mul_tensor->src[0];
    } else {
        GGML_ASSERT(false);
    }
    // The RMS_NORM buffer is never filled by this fused path, so it cannot also be the weight.
    GGML_ASSERT(mul_src != dst);

    const ggml_tensor * add_src = nullptr;
    if (add_tensor->src[0] == mul_tensor) {
        add_src = add_tensor->src[1];
    } else if (add_tensor->src[1] == mul_tensor) {
        add_src = add_tensor->src[0];
    } else {
        GGML_ASSERT(false);
    }
    // Likewise the MUL buffer is never filled here, so it cannot also be the bias.
    GGML_ASSERT(add_src != mul_tensor);

    const float * mul_d = (const float *) mul_src->data;
    const float * add_d = (const float *) add_src->data;

    float * dst_d = (float *) add_tensor->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
    GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);
    // The kernel reads weight and bias as float; their strides are derived from their own type size.
    GGML_ASSERT(mul_src->type == GGML_TYPE_F32);
    GGML_ASSERT(add_src->type == GGML_TYPE_F32);
    GGML_ASSERT(eps >= 0.0f);

    const int64_t ne00 = rms_norm_src->ne[0];
    const int64_t ne01 = rms_norm_src->ne[1];
    const int64_t ne02 = rms_norm_src->ne[2];
    const int64_t ne03 = rms_norm_src->ne[3];

    const size_t ts0 = ggml_type_size(rms_norm_src->type);
    GGML_ASSERT(rms_norm_src->nb[0] == ts0);
    const int64_t s01 = rms_norm_src->nb[1] / ts0;
    const int64_t s02 = rms_norm_src->nb[2] / ts0;
    const int64_t s03 = rms_norm_src->nb[3] / ts0;

    const size_t ts_mul = ggml_type_size(mul_src->type);
    GGML_ASSERT(mul_src->nb[0] == ts_mul);
    const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;

    const int mul_ncols     = mul_src->ne[0];
    const int mul_nrows     = mul_src->ne[1];
    const int mul_nchannels = mul_src->ne[2];
    const int mul_nsamples  = mul_src->ne[3];

    const size_t ts_add = ggml_type_size(add_src->type);
    GGML_ASSERT(add_src->nb[0] == ts_add);
    const int64_t add_s01 = add_src->nb[1] / ts_add;
    const int64_t add_s02 = add_src->nb[2] / ts_add;
    const int64_t add_s03 = add_src->nb[3] / ts_add;

    const int add_ncols     = add_src->ne[0];
    const int add_nrows     = add_src->ne[1];
    const int add_nchannels = add_src->ne[2];
    const int add_nsamples  = add_src->ne[3];

    rms_norm_mul_f32_cuda(src0_d, mul_d, add_d, dst_d,
                          ne00, ne01, ne02, ne03,
                          /*s00*/ s01, s02, s03,
                          /*mul_s00*/ mul_s01, mul_s02, mul_s03,
                          mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
                          /*add_s00*/ add_s01, add_s02, add_s03,
                          add_ncols, add_nrows, add_nchannels, add_nsamples,
                          eps, stream);
}
623
// GGML_OP_RMS_NORM_BACK on CUDA (f32): dst->src[0] is the incoming gradient, dst->src[1] the
// forward-pass input; op_params[0] holds eps (as float bits). The gradient w.r.t. the input is
// written to dst. Rows have length src1->ne[0].
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * grad  = dst->src[0]; // gradients
    const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass

    const float * grad_d  = (const float *) grad->data;
    const float * src0f_d = (const float *) src0f->data;
    float       * dst_d   = (float *) dst->data;

    cudaStream_t stream = ctx.stream();

    // rms_norm_back_f32 addresses row r of grad, src0f and dst at r*ne00, so all three must
    // be densely packed and hold the same number of elements.
    GGML_ASSERT(ggml_is_contiguous(grad));
    GGML_ASSERT(ggml_is_contiguous(src0f));
    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_nelements(grad) == ggml_nelements(src0f));

    GGML_ASSERT( grad->type == GGML_TYPE_F32);
    GGML_ASSERT(src0f->type == GGML_TYPE_F32);
    GGML_ASSERT(  dst->type == GGML_TYPE_F32);

    const int64_t ne00  = src0f->ne[0];
    const int64_t nrows = ggml_nrows(src0f);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
}
649
// GGML_OP_L2_NORM on CUDA: f32 L2 normalization along dim 0 of dst->src[0], written to dst.
// op_params[0] holds eps (as float bits). Dim 0 must be contiguous; dims 1..3 may be strided.
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_TENSOR_UNARY_OP_LOCALS;

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    const size_t elem_size = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == elem_size);

    // byte strides -> element strides
    const int64_t row_stride     = nb01 / elem_size;
    const int64_t channel_stride = nb02 / elem_size;
    const int64_t sample_stride  = nb03 / elem_size;

    const float * in_d  = (const float *) src0->data;
    float       * out_d = (float *) dst->data;
    l2_norm_f32_cuda(in_d, out_d, ne00, ne01, ne02, ne03, row_stride, channel_stride, sample_stride, eps, stream);
}