1#include "ggml.h"
  2#include "mmf.cuh"
  3#include "mmid.cuh"
  4
  5static __forceinline__ int mmf_get_rows_per_block(const int cc) {
  6    if (GGML_CUDA_CC_IS_CDNA(cc)) {
  7        return MMF_ROWS_PER_BLOCK_CDNA;
  8    } else {
  9        return MMF_ROWS_PER_BLOCK;
 10    }
 11}
 12
 13void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
 14    GGML_ASSERT(        src1->type == GGML_TYPE_F32);
 15    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
 16    GGML_ASSERT(         dst->type == GGML_TYPE_F32);
 17
 18
 19    GGML_TENSOR_BINARY_OP_LOCALS;
 20
 21    const size_t ts_src0 = ggml_type_size(src0->type);
 22    const size_t ts_src1 = ggml_type_size(src1->type);
 23    const size_t ts_dst  = ggml_type_size(dst->type);
 24
 25    GGML_ASSERT(ne13 == ne3);
 26
 27    GGML_ASSERT(        nb00       == ts_src0);
 28    GGML_ASSERT(        nb10       == ts_src1);
 29    GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
 30    GGML_ASSERT(        nb0        == ts_dst);
 31
 32    const float   * src1_d =       (const float   *) src1->data;
 33    const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
 34    float         *  dst_d =       (float         *)  dst->data;
 35
 36    const int64_t s01 = src0->nb[1] / ts_src0;
 37    const int64_t s11 = src1->nb[1] / ts_src1;
 38    const int64_t s1  =  dst->nb[1] / ts_dst;
 39    const int64_t s02 = src0->nb[2] / ts_src0;
 40    const int64_t s12 = src1->nb[2] / ts_src1;
 41    const int64_t s2  =  dst->nb[2] / ts_dst;
 42    const int64_t s03 = src0->nb[3] / ts_src0;
 43    const int64_t s13 = src1->nb[3] / ts_src1;
 44    const int64_t s3  =  dst->nb[3] / ts_dst;
 45
 46    const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
 47    const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
 48
 49    mmf_ids_data ids_info{};
 50    mmf_ids_data * ids_info_ptr = nullptr;
 51    ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
 52    ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
 53    ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;
 54
 55    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
 56    const int64_t ncols_dst          = ids ? ne2  : ne1;
 57    const int64_t nchannels_dst      = ids ? ne1 : ne2;
 58
 59    const int64_t stride_col_dst     = ids ? s2   : s1;
 60    const int64_t stride_col_y       = ids ? s12  : s11;
 61    const int64_t stride_channel_dst = ids ? s1 : s2;
 62
 63    int64_t stride_channel_y         = ids ? s11  : s12;
 64    int64_t nchannels_y              = ids ? ne11 : ne12;
 65
 66    //mul_mat_id: handle broadcast
 67    if (ids && nchannels_y == 1) {
 68        stride_channel_y = 0;
 69        nchannels_y      = ids->ne[0];
 70    }
 71
 72    if (ids && ncols_dst > 16) {
 73        const int64_t n_expert_used = ids->ne[0];
 74        const int64_t n_experts     = ne02;
 75        const int64_t n_tokens      = ne12;
 76        const int64_t ne_get_rows   = n_tokens * n_expert_used;
 77
 78        ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
 79        ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
 80        expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);
 81
 82        const int si1  = static_cast<int>(ids_s1);
 83        const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);
 84
 85        GGML_ASSERT(sis1 > 0);
 86
 87        ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
 88            static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
 89        CUDA_CHECK(cudaGetLastError());
 90
 91        ids_info.ids_src_compact   = ids_src_compact_dev.get();
 92        ids_info.ids_dst_compact   = ids_dst_compact_dev.get();
 93        ids_info.expert_bounds_dev = expert_bounds_dev.get();
 94        ids_info.n_experts         = static_cast<int>(n_experts);
 95        ids_info.sis1              = sis1;
 96        ids_info_ptr = &ids_info;
 97    }
 98
 99    const int device    = ggml_cuda_get_device();
100    const int cc        = ggml_cuda_info().devices[device].cc;
101    const int rows_per_block = mmf_get_rows_per_block(cc);
102
103    switch (src0->type) {
104        case GGML_TYPE_F32: {
105            const float * src0_d = (const float *) src0->data;
106            constexpr int vals_per_T = 1;
107            mul_mat_f_switch_rows_per_block<float>(
108                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
109                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
110                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
111        } break;
112        case GGML_TYPE_F16: {
113            const half2 * src0_d = (const half2 *) src0->data;
114            constexpr int vals_per_T = 2;
115            mul_mat_f_switch_rows_per_block<half2>(
116                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
117                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
118                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
119        } break;
120        case GGML_TYPE_BF16: {
121            const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
122            constexpr int vals_per_T = 2;
123            mul_mat_f_switch_rows_per_block<nv_bfloat162>(
124                rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
125                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
126                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
127        } break;
128        default:
129            GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
130    }
131}
132
// Decides whether the MMF path (ggml_cuda_mul_mat_f) should be used for a matrix multiplication.
//
// type:       data type of src0
// cc:         compute capability of the device
// warp_size:  warp size used for the divisibility check on src0_ne[0]
// src0_ne:    number of elements of src0 per dimension, only [0] and [1] are read
// src0_nb:    byte strides of src0, entries [0] .. [GGML_MAX_DIMS-1] are read
// src1_ncols: number of src1 columns
// mul_mat_id: true for MUL_MAT_ID, false for regular MUL_MAT
//
// Returns true only if type, shape, strides and device pass all of the checks below.
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne,
        const size_t * src0_nb, const int src1_ncols, bool mul_mat_id) {
    if (ggml_is_quantized(type)) {
        return false;
    }

    const size_t ts = ggml_type_size(type);

    // 4/ts is the number of values per 32 bits. For a type size of 0 or of more than 4 bytes the
    // modulo below would divide by zero; such types are not handled by the final switch anyway.
    if (ts == 0 || ts > 4) {
        return false;
    }

    if (src0_ne[0] % (warp_size * (4/ts)) != 0) {
        return false;
    }

    if (src0_nb[0] != ts) {
        return false;
    }

    // Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
    for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
        if (src0_nb[i] % (2*ts) != 0) {
            return false;
        }
    }

    // The number of src0 rows has to be a multiple of the rows-per-block constant for this device.
    if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) {
        return false;
    }

    if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) {
        return false;
    }

    if (mul_mat_id) {
        // MUL_MAT_ID: the column limit depends on the number of src0 rows.
        if (src0_ne[1] <= 1024 && src1_ncols > 512) {
            return false;
        } else if (src0_ne[1] > 1024 && src1_ncols > 128) {
            return false;
        }
    } else {
        if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) {
            return false;
        } else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
            // TODO: CDNA2 is currently treated like CDNA1, tune the performance when CDNA2 hardware is available.
            return false;
        } else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
            return false;
        } else if (src1_ncols > 16) {
            return false;
        }
    }

    // Finally, the device needs matrix instructions that are available for the given type.
    switch (type) {
        case GGML_TYPE_F32:
            return ampere_mma_available(cc) || amd_mfma_available(cc);
        case GGML_TYPE_F16:
            return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
        case GGML_TYPE_BF16:
            return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
        default:
            return false;
    }
}