#include "common.cuh"

#include <cstdint>
 2
 3// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
 4template <bool norm>
 5static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
 6    const int row = blockIdx.x;
 7    const int col = threadIdx.x;
 8
 9    float     sum        = 0.0f;
10    const int num_unroll = 8;
11    float     temp[num_unroll];
12    float     sum_temp[num_unroll] = { 0.0f };
13    for (int i = col; i < ncols;) {
14        for (int j = 0; j < num_unroll; ++j) {
15            if (i < ncols) {
16                temp[j] = x[row * ncols + i];
17            } else {
18                temp[j] = 0;
19            }
20            i += blockDim.x;
21        }
22        for (int j = 0; j < num_unroll; ++j) {
23            sum_temp[j] += temp[j];
24        }
25    }
26    for (int j = 0; j < num_unroll; ++j) {
27        sum += sum_temp[j];
28    }
29
30    // sum up partial sums
31    __shared__ float shared_vals[32];
32    sum = block_reduce<block_reduce_method::SUM>(sum, shared_vals);
33
34    if (col != 0) {
35        return;
36    }
37
38    dst[row] = norm ? sum / ncols : sum;
39}