#include "norm.cuh"
#include <cstdint>

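// Layer-norm style normalization over the innermost dimension (ncols):
//   dst[i] = (x[i] - mean(x)) * rsqrt(var(x) + eps)
// One block processes one row; the per-thread partials (sum x, sum x^2) are
// kept in a float2 and combined with a block-wide sum reduction.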
template <int block_size>
static __global__ void norm_f32(
        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const int row       = blockIdx.x;
    const int channel   = blockIdx.y;
    const int sample    = blockIdx.z;
    const int tid       = threadIdx.x;

    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
    dst += ((sample*nchannels + channel)*nrows + row)*ncols;

    float2 mean_var = make_float2(0.0f, 0.0f);

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        mean_var.x += xi;
        mean_var.y += xi * xi;
    }

    // sum up partial sums
    extern __shared__ float2 s_sum2[];
    mean_var = block_reduce<block_reduce_method::SUM, block_size>(mean_var, s_sum2);

    const float mean = mean_var.x / ncols;
    const float var = mean_var.y / ncols - mean * mean;
    const float inv_std = rsqrtf(var + eps);

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = (x[col] - mean) * inv_std;
    }
}

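// Group normalization: each block normalizes one contiguous group of
// group_size elements (clipped to ne_elements for the last group). Two block
// reductions are required because the mean must be known before the centered
// values and their variance can be accumulated.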
template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
    // blockIdx.x: num_groups idx
    // threadIdx.x: block_size idx
    const int start =     blockIdx.x*group_size + threadIdx.x;
    const int end   = min(blockIdx.x*group_size + group_size,  ne_elements);

    float tmp = 0.0f; // partial sum for thread in warp

    for (int j = start; j < end; j += block_size) {
        tmp += x[j];
    }

    extern __shared__ float s_sum[];
    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);

    const float mean = tmp / group_size;
    tmp = 0.0f;

    for (int j = start; j < end; j += block_size) {
        const float xi = x[j] - mean;
        dst[j] = xi;
        tmp += xi * xi;
    }

    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);

    const float variance = tmp / group_size;
    const float scale = rsqrtf(variance + eps);
    for (int j = start; j < end; j += block_size) {
        dst[j] *= scale;
    }
}

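// RMS normalization: dst[i] = x[i] * rsqrt(mean(x^2) + eps), optionally fused
// with an elementwise multiply (e.g. the norm weight) and an add. The mul/add
// operands may be broadcast, so their row/channel/sample/column indices are
// wrapped with fastmodulo() using the precomputed *_packed fastdiv constants.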
template <int block_size, bool do_multiply = false, bool do_add = false>
static __global__ void rms_norm_f32(const float * x,
                                    float *       dst,
                                    const int     ncols,
                                    const int64_t stride_row,
                                    const int64_t stride_channel,
                                    const int64_t stride_sample,
                                    const float   eps,
                                    const float * mul                  = nullptr,
                                    const int64_t mul_stride_row       = 0,
                                    const int64_t mul_stride_channel   = 0,
                                    const int64_t mul_stride_sample    = 0,
                                    const uint3   mul_ncols_packed     = make_uint3(0, 0, 0),
                                    const uint3   mul_nrows_packed     = make_uint3(0, 0, 0),
                                    const uint3   mul_nchannels_packed = make_uint3(0, 0, 0),
                                    const uint3   mul_nsamples_packed  = make_uint3(0, 0, 0),
                                    const float * add                  = nullptr,
                                    const int64_t add_stride_row       = 0,
                                    const int64_t add_stride_channel   = 0,
                                    const int64_t add_stride_sample    = 0,
                                    const uint3   add_ncols_packed     = make_uint3(0, 0, 0),
                                    const uint3   add_nrows_packed     = make_uint3(0, 0, 0),
                                    const uint3   add_nchannels_packed = make_uint3(0, 0, 0),
                                    const uint3   add_nsamples_packed  = make_uint3(0, 0, 0)) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const int row       = blockIdx.x;
    const int channel   = blockIdx.y;
    const int sample    = blockIdx.z;
    const int tid       = threadIdx.x;

    static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");

    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
    dst += ((sample*nchannels + channel)*nrows + row)*ncols;

    if constexpr (do_multiply) {
        const uint32_t mul_row     = fastmodulo(row, mul_nrows_packed);
        const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
        const uint32_t mul_sample  = fastmodulo(sample, mul_nsamples_packed);
        mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
    }

    if constexpr (do_add) {
        const int add_row     = fastmodulo(row, add_nrows_packed);
        const int add_channel = fastmodulo(channel, add_nchannels_packed);
        const int add_sample  = fastmodulo(sample, add_nsamples_packed);
        add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
    }

    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        tmp += xi * xi;
    }

    // sum up partial sums
    extern __shared__ float s_sum[];
    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);

    const float mean = tmp / ncols;
    const float scale = rsqrtf(mean + eps);

    for (int col = tid; col < ncols; col += block_size) {
        if constexpr (do_multiply && do_add) {
            const int mul_col = fastmodulo(col, mul_ncols_packed);
            const int add_col = fastmodulo(col, add_ncols_packed);
            dst[col]          = scale * x[col] * mul[mul_col] + add[add_col];
        } else if constexpr (do_multiply) {
            const int mul_col = fastmodulo(col, mul_ncols_packed);
            dst[col]          = scale * x[col] * mul[mul_col];
        } else {
            dst[col] = scale * x[col];
        }
    }
}

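// Backward pass of RMS norm. With m = sum(x^2)/ncols + eps and r = rsqrt(m),
// the forward pass is y_i = r * x_i, so the input gradient is
//   dL/dx_i = r * g_i - (sum_j g_j * x_j) / (ncols * m^(3/2)) * x_i
// which is evaluated below as scale_grad*grad[i] + scale_x*xf[i], using
// sum_eps = sum_xx + ncols*eps = ncols*m.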
template <int block_size>
static __global__ void rms_norm_back_f32(
        const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;

    grad += int64_t(row)*ncols;
    xf   += int64_t(row)*ncols;
    dst  += int64_t(row)*ncols;

    float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
    float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs

    for (int col = tid; col < ncols; col += block_size) {
        const float xfi = xf[col];
        sum_xx += xfi * xfi;
        sum_xg += xfi * grad[col];
    }

    // sum up partial sums
    sum_xx = warp_reduce_sum(sum_xx);
    sum_xg = warp_reduce_sum(sum_xg);
    if constexpr (block_size > WARP_SIZE) {
        static_assert(block_size == 1024, "unexpected block_size");
        __shared__ float s_sum_xx[32];
        __shared__ float s_sum_xg[32];
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            s_sum_xx[warp_id] = sum_xx;
            s_sum_xg[warp_id] = sum_xg;
        }
        __syncthreads();

        sum_xx = s_sum_xx[lane_id];
        sum_xx = warp_reduce_sum(sum_xx);

        sum_xg = s_sum_xg[lane_id];
        sum_xg = warp_reduce_sum(sum_xg);
    }

    const float mean_eps = sum_xx / ncols + eps;
    const float sum_eps  = sum_xx + ncols*eps;

    const float scale_grad = rsqrtf(mean_eps);
    const float scale_x    = -scale_grad * sum_xg/sum_eps;

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = scale_grad*grad[col] + scale_x*xf[col];
    }
}

// template <int block_size>
// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
//     const int row = blockIdx.x*blockDim.y + threadIdx.y;
//     const int tid = threadIdx.x;

//     float tmp = 0.0f; // partial sum for thread in warp

//     for (int col = tid; col < ncols; col += block_size) {
//         const float xi = x[row*ncols + col];
//         tmp += xi * xi;
//     }

//     // sum up partial sums
//     tmp = warp_reduce_sum(tmp);
//     if (block_size > WARP_SIZE) {
//         __shared__ float s_sum[32];
//         int warp_id = threadIdx.x / WARP_SIZE;
//         int lane_id = threadIdx.x % WARP_SIZE;
//         if (lane_id == 0) {
//             s_sum[warp_id] = tmp;
//         }
//         __syncthreads();
//         tmp = s_sum[lane_id];
//         tmp = warp_reduce_sum(tmp);
//     }

//     // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
//     const float scale = rsqrtf(fmaxf(tmp, eps * eps));

//     for (int col = tid; col < ncols; col += block_size) {
//         dst[row*ncols + col] = scale * x[row*ncols + col];
//     }
// }

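// L2 normalization over the innermost dimension:
//   dst[i] = x[i] * rsqrt(max(sum(x^2), eps^2)) = x[i] / max(||x||_2, eps)
// following the torch.nn.functional.normalize convention linked below.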
template <int block_size>
static __global__ void l2_norm_f32(
        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
        const int64_t stride_sample, const float eps) {
    const int nrows     = gridDim.x;
    const int nchannels = gridDim.y;

    const int row       = blockIdx.x;
    const int channel   = blockIdx.y;
    const int sample    = blockIdx.z;
    const int tid       = threadIdx.x;

    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
    dst += ((sample*nchannels + channel)*nrows + row)*ncols;

    float tmp = 0.0f; // partial sum for thread in warp

    for (int col = tid; col < ncols; col += block_size) {
        const float xi = x[col];
        tmp += xi * xi;
    }

    // sum up partial sums
    extern __shared__ float s_sum[];
    tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);

    // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
    const float scale = rsqrtf(fmaxf(tmp, eps * eps));

    for (int col = tid; col < ncols; col += block_size) {
        dst[col] = scale * x[col];
    }
}

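// Launch helpers: the grid is (nrows, nchannels, nsamples), one block per row.
// The block size is chosen from ncols; whenever it exceeds WARP_SIZE, dynamic
// shared memory for the cross-warp part of block_reduce is passed (one slot
// per warp, i.e. 32 entries).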
static void norm_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 blocks_num(nrows, nchannels, nsamples);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    }
}

static void group_norm_f32_cuda(
        const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
    if (group_size < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        group_norm_f32<1024><<<num_groups, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps);
    }
}

static void rms_norm_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 blocks_num(nrows, nchannels, nsamples);
    if (ncols < 1024) {
        const dim3 block_dims(256, 1, 1);
        rms_norm_f32<256, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        rms_norm_f32<1024, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    }
}

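// Dispatch for the fused RMS norm: falls back to the plain kernel when there
// is no mul operand, otherwise precomputes the fastdiv constants for the
// broadcast dimensions and picks the do_multiply / do_add instantiation.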
static void rms_norm_mul_f32_cuda(const float *  x,
                                  const float *  mul,
                                  const float *  add,
                                  float *        dst,
                                  const int      ncols,
                                  const int      nrows,
                                  const int      nchannels,
                                  const int      nsamples,
                                  const int64_t  stride_row,
                                  const int64_t  stride_channel,
                                  const int64_t  stride_sample,
                                  const int64_t  mul_stride_row,
                                  const int64_t  mul_stride_channel,
                                  const int64_t  mul_stride_sample,
                                  const uint32_t mul_ncols,
                                  const uint32_t mul_nrows,
                                  const uint32_t mul_nchannels,
                                  const uint32_t mul_nsamples,
                                  const int64_t  add_stride_row,
                                  const int64_t  add_stride_channel,
                                  const int64_t  add_stride_sample,
                                  const uint32_t add_ncols,
                                  const uint32_t add_nrows,
                                  const uint32_t add_nchannels,
                                  const uint32_t add_nsamples,
                                  const float    eps,
                                  cudaStream_t   stream) {
    const dim3 blocks_num(nrows, nchannels, nsamples);
    if (mul == nullptr) {
        rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
        return;
    }
    if (add == nullptr) {
        const uint3 mul_ncols_packed     = init_fastdiv_values(mul_ncols);
        const uint3 mul_nrows_packed     = init_fastdiv_values(mul_nrows);
        const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
        const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);
        if (ncols < 1024) {
            const dim3 block_dims(256, 1, 1);
            rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
        } else {
            const dim3 block_dims(1024, 1, 1);
            rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
        }
    } else {
        const uint3 mul_ncols_packed     = init_fastdiv_values(mul_ncols);
        const uint3 mul_nrows_packed     = init_fastdiv_values(mul_nrows);
        const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
        const uint3 mul_nsamples_packed  = init_fastdiv_values(mul_nsamples);

        const uint3 add_ncols_packed     = init_fastdiv_values(add_ncols);
        const uint3 add_nrows_packed     = init_fastdiv_values(add_nrows);
        const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels);
        const uint3 add_nsamples_packed  = init_fastdiv_values(add_nsamples);
        if (ncols < 1024) {
            const dim3 block_dims(256, 1, 1);
            rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
                add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
                add_nchannels_packed, add_nsamples_packed);
        } else {
            const dim3 block_dims(1024, 1, 1);
            rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
                x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
                mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
                add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
                add_nchannels_packed, add_nsamples_packed);
        }
    }
}

static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        rms_norm_back_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        rms_norm_back_f32<1024><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
    }
}

static void l2_norm_f32_cuda(
        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
    const dim3 blocks_num(nrows, nchannels, nsamples);
    if (ncols < 1024) {
        const dim3 block_dims(WARP_SIZE, 1, 1);
        l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    } else {
        const dim3 block_dims(1024, 1, 1);
        l2_norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
    }
}

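// NORM host op: eps comes from op_params; byte strides (nb0x) are converted to
// element strides (s0x), and only the innermost dimension must be contiguous.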
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_TENSOR_UNARY_OP_LOCALS;

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    const size_t ts0 = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == ts0);
    const int64_t s01 = nb01 / ts0;
    const int64_t s02 = nb02 / ts0;
    const int64_t s03 = nb03 / ts0;

    norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}

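// GROUP_NORM host op: op_params holds {num_groups, eps}. The group size is
// computed as ne0*ne1*ceil(ne2/num_groups) and num_groups*ne3 blocks are
// launched over the flattened tensor.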
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    int num_groups = dst->op_params[0];

    float eps;
    memcpy(&eps, dst->op_params + 1, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
    group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
}

void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_TENSOR_UNARY_OP_LOCALS;

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    const size_t ts0 = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == ts0);
    const int64_t s01 = nb01 / ts0;
    const int64_t s02 = nb02 / ts0;
    const int64_t s03 = nb03 / ts0;

    rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}

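// Fused RMS_NORM + MUL: dst is the RMS_NORM node and mul_tensor the MUL node
// consuming it. The operand of the MUL that is not the norm result is used as
// the scaling tensor, and the output is written directly to mul_tensor->data.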
void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
    float eps = 0.0f;

    memcpy(&eps, dst->op_params, sizeof(float));

    const float * src0_d = (const float *) rms_norm_src->data;
    const float * mul_d = nullptr;
    const ggml_tensor * mul_src = nullptr;

    if (mul_tensor->src[0] == dst) {
        mul_d = (float *) mul_tensor->src[1]->data;
        mul_src = mul_tensor->src[1];
    } else if (mul_tensor->src[1] == dst) {
        mul_d = (float *) mul_tensor->src[0]->data;
        mul_src = mul_tensor->src[0];
    } else {
        GGML_ASSERT(false);
    }

    float * dst_d = (float *) mul_tensor->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
    GGML_ASSERT(eps >= 0.0f);

    const int64_t ne00 = rms_norm_src->ne[0];
    const int64_t ne01 = rms_norm_src->ne[1];
    const int64_t ne02 = rms_norm_src->ne[2];
    const int64_t ne03 = rms_norm_src->ne[3];

    const size_t ts0 = ggml_type_size(rms_norm_src->type);
    GGML_ASSERT(rms_norm_src->nb[0] == ts0);
    const int64_t s01 = rms_norm_src->nb[1] / ts0;
    const int64_t s02 = rms_norm_src->nb[2] / ts0;
    const int64_t s03 = rms_norm_src->nb[3] / ts0;

    const size_t ts_mul = ggml_type_size(mul_src->type);
    GGML_ASSERT(mul_src->nb[0] == ts_mul);
    const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;

    const int mul_ncols     = mul_src->ne[0];
    const int mul_nrows     = mul_src->ne[1];
    const int mul_nchannels = mul_src->ne[2];
    const int mul_nsamples  = mul_src->ne[3];

    rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,
                          ne00, ne01, ne02, ne03,
                          /*s00*/ s01, s02, s03,
                          /*mul_s00*/ mul_s01, mul_s02, mul_s03,
                          mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
                          /*add_s00*/ 0, 0, 0,
                          0, 0, 0, 0,
                          eps, stream);
}

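// Fused RMS_NORM + MUL + ADD: same pattern as above, with the ADD node that
// consumes the MUL supplying the second operand; the final result is written
// to add_tensor->data.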
void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
                                     ggml_tensor *               dst,
                                     ggml_tensor *               mul_tensor,
                                     ggml_tensor *               add_tensor) {
    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
    float               eps          = 0.0f;

    memcpy(&eps, dst->op_params, sizeof(float));

    const float *       src0_d  = (const float *) rms_norm_src->data;
    const float *       mul_d   = nullptr;
    const ggml_tensor * mul_src = nullptr;

    if (mul_tensor->src[0] == dst) {
        mul_d   = (float *) mul_tensor->src[1]->data;
        mul_src = mul_tensor->src[1];
    } else if (mul_tensor->src[1] == dst) {
        mul_d   = (float *) mul_tensor->src[0]->data;
        mul_src = mul_tensor->src[0];
    } else {
        GGML_ASSERT(false);
    }

    const float *       add_d   = nullptr;
    const ggml_tensor * add_src = nullptr;

    if (add_tensor->src[0] == mul_tensor) {
        add_d   = (float *) add_tensor->src[1]->data;
        add_src = add_tensor->src[1];
    } else if (add_tensor->src[1] == mul_tensor) {
        add_d   = (float *) add_tensor->src[0]->data;
        add_src = add_tensor->src[0];
    } else {
        GGML_ASSERT(false);
    }

    float *      dst_d  = (float *) add_tensor->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
    GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
    GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);
    GGML_ASSERT(eps >= 0.0f);

    const int64_t ne00 = rms_norm_src->ne[0];
    const int64_t ne01 = rms_norm_src->ne[1];
    const int64_t ne02 = rms_norm_src->ne[2];
    const int64_t ne03 = rms_norm_src->ne[3];

    const size_t ts0 = ggml_type_size(rms_norm_src->type);
    GGML_ASSERT(rms_norm_src->nb[0] == ts0);
    const int64_t s01 = rms_norm_src->nb[1] / ts0;
    const int64_t s02 = rms_norm_src->nb[2] / ts0;
    const int64_t s03 = rms_norm_src->nb[3] / ts0;

    const size_t ts_mul = ggml_type_size(mul_src->type);
    GGML_ASSERT(mul_src->nb[0] == ts_mul);
    const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;

    const int mul_ncols     = mul_src->ne[0];
    const int mul_nrows     = mul_src->ne[1];
    const int mul_nchannels = mul_src->ne[2];
    const int mul_nsamples  = mul_src->ne[3];

    const size_t ts_add = ggml_type_size(add_src->type);
    GGML_ASSERT(add_src->nb[0] == ts_add);
    const int64_t add_s01 = add_src->nb[1] / ts_add;
    const int64_t add_s02 = add_src->nb[2] / ts_add;
    const int64_t add_s03 = add_src->nb[3] / ts_add;

    const int add_ncols     = add_src->ne[0];
    const int add_nrows     = add_src->ne[1];
    const int add_nchannels = add_src->ne[2];
    const int add_nsamples  = add_src->ne[3];

    rms_norm_mul_f32_cuda(src0_d, mul_d, add_d, dst_d,
                          ne00, ne01, ne02, ne03,
                          /*s00*/ s01, s02, s03,
                          /*mul_s00*/ mul_s01, mul_s02, mul_s03,
                          mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
                          /*add_s00*/ add_s01, add_s02, add_s03,
                          add_ncols, add_nrows, add_nchannels, add_nsamples,
                          eps, stream);
}

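// RMS_NORM_BACK host op: src[0] holds the incoming gradients and src[1] the
// original forward-pass input; the gradients must be contiguous.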
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * grad  = dst->src[0]; // gradients
    const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass

    const float * grad_d  = (const float *) grad->data;
    const float * src0f_d = (const float *) src0f->data;
    float       * dst_d   = (float       *) dst->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(grad));

    GGML_ASSERT( grad->type == GGML_TYPE_F32);
    GGML_ASSERT(src0f->type == GGML_TYPE_F32);
    GGML_ASSERT(  dst->type == GGML_TYPE_F32);

    const int64_t ne00 = src0f->ne[0];
    const int64_t nrows = ggml_nrows(src0f);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
}

void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *) src0->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);

    GGML_TENSOR_UNARY_OP_LOCALS;

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
    GGML_ASSERT(eps >= 0.0f);

    const size_t ts0 = ggml_type_size(src0->type);
    GGML_ASSERT(nb00 == ts0);
    const int64_t s01 = nb01 / ts0;
    const int64_t s02 = nb02 / ts0;
    const int64_t s03 = nb03 / ts0;

    l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}