1#include "common.cuh"
  2#include "mmid.cuh"
  3
  4// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
  5struct mm_ids_helper_store {
  6    uint32_t data;
  7
  8    __device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
  9        data = (it & 0x003FFFFF) | (iex_used << 22);
 10    }
 11
 12    __device__ uint32_t it() const {
 13        return data & 0x003FFFFF;
 14    }
 15
 16    __device__ uint32_t iex_used() const {
 17        return data >> 22;
 18    }
 19};
 20static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store");
 21
 22// Helper function for mul_mat_id, converts ids to a more convenient format.
 23// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
 24// ids_dst describes the same mapping but for the dst tensor.
 25// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
 26template <int n_expert_used_template>
 27__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
 28static __global__ void mm_ids_helper(
 29        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
 30        const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
 31    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 32    const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
 33    const int expert = blockIdx.x;
 34
 35    extern __shared__ char data_mm_ids_helper[];
 36    mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper;
 37
 38    int nex_prev   = 0; // Number of columns for experts with a lower index.
 39    int it_compact = 0; // Running index for the compact slice of this expert.
 40
 41    if constexpr (n_expert_used_template == 0) {
 42        // Generic implementation:
 43        for (int it = 0; it < n_tokens; ++it) {
 44            int iex_used = -1; // The index at which the expert is used, if any.
 45            for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
 46                const int expert_used = ids[it*si1 + iex];
 47                nex_prev += expert_used < expert;
 48                if (expert_used == expert) {
 49                    iex_used = iex;
 50                }
 51            }
 52
 53            if (iex_used != -1) {
 54                store[it_compact] = mm_ids_helper_store(it, iex_used);
 55            }
 56
 57            if (warp_reduce_any<warp_size>(iex_used != -1)) {
 58                it_compact++;
 59            }
 60        }
 61    } else {
 62        // Implementation optimized for specific numbers of experts used:
 63        static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
 64        const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
 65        for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
 66            const int it = it0 + threadIdx.x / neu_padded;
 67
 68            const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
 69            const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
 70                ids[it*si1 + iex] : INT_MAX;
 71            const int iex_used = expert_used == expert ? iex : -1;
 72            nex_prev += expert_used < expert;
 73
 74            // Whether the threads at this token position have used the expert:
 75            const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
 76
 77            // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
 78            int it_compact_add_lower = 0;
 79#pragma unroll
 80            for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
 81                const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
 82                if (threadIdx.x >= static_cast<unsigned int>(offset)) {
 83                    it_compact_add_lower += tmp;
 84                }
 85            }
 86
 87            if (iex_used != -1) {
 88                store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used);
 89            }
 90
 91            // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
 92            it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
 93        }
 94    }
 95    nex_prev = warp_reduce_sum<warp_size>(nex_prev);
 96
 97    for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
 98        const mm_ids_helper_store store_it = store[itc];
 99        const int it       = store_it.it();
100        const int iex_used = store_it.iex_used();
101        ids_src1[nex_prev + itc] = it*sis1          + iex_used % nchannels_y;
102        ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
103    }
104
105    if (threadIdx.x != 0) {
106        return;
107    }
108
109    expert_bounds[expert] = nex_prev;
110
111    if (expert < static_cast<int>(gridDim.x) - 1) {
112        return;
113    }
114
115    expert_bounds[gridDim.x] = nex_prev + it_compact;
116}
117
// Host-side launcher for mm_ids_helper: one single-warp block per expert, with n_tokens packed
// mm_ids_helper_store entries of dynamic shared memory per block.
template <int n_expert_used_template>
static void launch_mm_ids_helper(
        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
        const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
    // The packed shared-memory entries reserve 22 bits for the token index and 10 bits for the expert slot.
    GGML_ASSERT(n_tokens          < (1 << 22) && "too few bits in mm_ids_helper_store");
    GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store");

    const auto & device_info = ggml_cuda_info().devices[ggml_cuda_get_device()];
    const size_t smpbo = device_info.smpbo;
    CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper<n_expert_used_template>, smpbo);

    // Every block keeps one packed entry per token in shared memory.
    const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store);
    GGML_ASSERT(nbytes_shared <= smpbo);

    mm_ids_helper<n_expert_used_template><<<dim3(n_experts, 1, 1), dim3(device_info.warp_size, 1, 1), nbytes_shared, stream>>>
        (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}
137
// Dispatches to a kernel specialized for the given number of experts used per token when one exists,
// otherwise to the generic kernel (template argument 0) that reads n_expert_used at runtime.
void ggml_cuda_launch_mm_ids_helper(
        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
        const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
    switch (n_expert_used) {
        case  2: launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); return;
        case  4: launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); return;
        case  6: launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); return;
        case  8: launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); return;
        case 16: launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); return;
        case 32: launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream); return;
        default: break;
    }

    // No specialization for this number of experts used: fall back to the generic kernel.
    launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
}