llama.cpp/ggml/src/ggml-cuda/common.cuh
#pragma once

#include "ggml.h"
#include "ggml-impl.h"
#include "ggml-cuda.h"

#include <cstdint>
#include <memory>

#if defined(GGML_USE_HIP)
#define GGML_COMMON_DECL_HIP
#define GGML_COMMON_IMPL_HIP
#else
#define GGML_COMMON_DECL_CUDA
#define GGML_COMMON_IMPL_CUDA
#if defined(GGML_USE_MUSA)
#define GGML_COMMON_DECL_MUSA
#define GGML_COMMON_IMPL_MUSA
#endif
#endif
#include "ggml-common.h"

#include <array>
#include <algorithm>
#include <cassert>
#include <cfloat>
#include <cstdio>
#include <cstring>       // memcpy (used by ggml_cuda_e8m0_to_fp32)
#include <limits>        // std::numeric_limits (used by init_fastdiv_values)
#include <string>
#include <type_traits>   // std::is_same_v (used by the block_reduce policies)
#include <unordered_map>
#include <vector>

#if defined(GGML_USE_HIP)
#include "vendors/hip.h"
#elif defined(GGML_USE_MUSA)
#include "vendors/musa.h"
#else
#include "vendors/cuda.h"
#endif // defined(GGML_USE_HIP)

#define STRINGIZE_IMPL(...) #__VA_ARGS__
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)

#define WARP_SIZE 32
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons

#define GGML_CUDA_CC_PASCAL 600
#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define GGML_CUDA_CC_VOLTA 700
#define GGML_CUDA_CC_TURING 750
#define GGML_CUDA_CC_AMPERE 800
#define GGML_CUDA_CC_ADA_LOVELACE 890
// While Blackwell spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to the 1200 family, see
// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
#define GGML_CUDA_CC_BLACKWELL 1200
#define GGML_CUDA_CC_DGX_SPARK 1210
#define GGML_CUDA_CC_RUBIN 1300
#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)

// AMD
// GCN/CDNA, wave size is 64
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum for acc register renaming
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300

// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
#define GGML_CUDA_CC_RDNA3_5 (GGML_CUDA_CC_OFFSET_AMD + 0x1150) // AI 370, AI Max 395 laptops.
#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000

#define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD)
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
#define GGML_CUDA_CC_IS_RDNA3_0(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA3_5)
#define GGML_CUDA_CC_IS_RDNA3_5(cc) (cc >= GGML_CUDA_CC_RDNA3_5 && cc < GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_RDNA3(cc) (GGML_CUDA_CC_IS_RDNA3_0(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc))
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2)
#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3)
#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)

// Moore Threads
#define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons

#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000

#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1)
#define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1)

#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
# define GGML_CUDA_USE_CUB
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070

#ifdef __CUDA_ARCH_LIST__
constexpr bool ggml_cuda_has_arch_impl(int) {
    return false;
}

template<class ... Archs>
constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
    return arch == first || ggml_cuda_has_arch_impl(arch, rest...);
}

constexpr bool ggml_cuda_has_arch(const int arch) {
    return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
}

constexpr int ggml_cuda_highest_compiled_arch_impl(const int /*arch*/, const int cur) {
    if (cur == 0) {
        return -1;
    }
    return cur;
}

template<class ... Archs>
constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) {
    if (first <= arch && first > cur) {
        return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...);
    } else {
        return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...);
    }
}

constexpr int ggml_cuda_highest_compiled_arch(const int arch) {
    return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__);
}
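
// Worked example (added note, values illustrative): if the build compiles device code
// for architectures 610 and 750, __CUDA_ARCH_LIST__ expands to 610,750 and:
//   ggml_cuda_highest_compiled_arch(700) == 610   // best compiled arch not above 700
//   ggml_cuda_highest_compiled_arch(520) == -1    // nothing compiled at or below 520
//   ggml_cuda_has_arch(750)              == true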
#else
static int ggml_cuda_highest_compiled_arch(const int arch) {
    return arch;
}
#endif // __CUDA_ARCH_LIST__

// ---------------------------------------------------------------------------------------------------------

#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses

#define GGML_CUDA_MAX_STREAMS 8

[[noreturn]]
void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);

#define CUDA_CHECK_GEN(err, success, error_fn) \
    do { \
        auto err_ = (err); \
        if (err_ != (success)) { \
            ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_)); \
        } \
    } while (0)

#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
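
// Typical use (hypothetical call site, not from this file): wrap any CUDA runtime call
// so that a failure aborts with file/line context, e.g.
//   CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyDeviceToDevice, stream));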

#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
    static const char * cublas_get_error_str(const cublasStatus_t err) {
        return cublasGetStatusString(err);
    }
#else
    static const char * cublas_get_error_str(const cublasStatus_t err) {
        switch (err) {
            case CUBLAS_STATUS_SUCCESS:          return "CUBLAS_STATUS_SUCCESS";
            case CUBLAS_STATUS_NOT_INITIALIZED:  return "CUBLAS_STATUS_NOT_INITIALIZED";
            case CUBLAS_STATUS_ALLOC_FAILED:     return "CUBLAS_STATUS_ALLOC_FAILED";
            case CUBLAS_STATUS_INVALID_VALUE:    return "CUBLAS_STATUS_INVALID_VALUE";
            case CUBLAS_STATUS_ARCH_MISMATCH:    return "CUBLAS_STATUS_ARCH_MISMATCH";
            case CUBLAS_STATUS_MAPPING_ERROR:    return "CUBLAS_STATUS_MAPPING_ERROR";
            case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
            case CUBLAS_STATUS_INTERNAL_ERROR:   return "CUBLAS_STATUS_INTERNAL_ERROR";
            case CUBLAS_STATUS_NOT_SUPPORTED:    return "CUBLAS_STATUS_NOT_SUPPORTED";
            default:                             return "unknown error";
        }
    }
#endif // CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)

#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)

#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
static const char * cu_get_error_str(CUresult err) {
    const char * err_str;
    cuGetErrorString(err, &err_str);
    return err_str;
}
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
#endif

#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
    do { \
        static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \
        const int id = ggml_cuda_get_device(); \
        if (!shared_memory_limit_raised[id]) { \
            CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
            shared_memory_limit_raised[id] = true; \
        } \
    } while (0)
#else
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
    do { \
        GGML_UNUSED(nbytes); \
    } while (0)
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)

#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else
#define GGML_CUDA_ASSUME(x)
#endif // CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)

#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
#define GGML_USE_VMM
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))

#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#define FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL

#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#define FAST_FP16_AVAILABLE
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610

#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
#define AMD_MFMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)

#if defined(GGML_USE_HIP) && (defined(RDNA4) || defined(RDNA3))
#define AMD_WMMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && (defined(RDNA4) || defined(RDNA3))

// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#define VOLTA_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA

#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define TURING_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING

#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define AMPERE_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN
# define BLACKWELL_MMA_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN

#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
#define FLASH_ATTN_AVAILABLE
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)

#if defined(TURING_MMA_AVAILABLE)
#define LDMATRIX_TRANS_AVAILABLE
#endif // defined(TURING_MMA_AVAILABLE)

static bool fp16_available(const int cc) {
    return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
           (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
}

static bool fast_fp16_available(const int cc) {
    return GGML_CUDA_CC_IS_AMD(cc) ||
           (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) ||
           (GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc));
}

// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fast_fp16_hardware_available(const int cc) {
    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc) ||
           (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
}

// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fp16_mma_hardware_available(const int cc) {
    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
           GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
           (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
}

static bool bf16_mma_hardware_available(const int cc) {
    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) ||
           GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 ||
           (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
}

static bool fp32_mma_hardware_available(const int cc) {
    return GGML_CUDA_CC_IS_CDNA(cc);
}

static bool amd_mfma_available(const int cc) {
#if !defined(GGML_HIP_NO_MMQ_MFMA)
    return GGML_CUDA_CC_IS_CDNA(cc);
#else
    return false;
#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
}

static bool amd_wmma_available(const int cc) {
    return (GGML_CUDA_CC_IS_RDNA4(cc) || GGML_CUDA_CC_IS_RDNA3(cc));
}

static bool volta_mma_available(const int cc) {
    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
}

static bool turing_mma_available(const int cc) {
    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}

static bool ampere_mma_available(const int cc) {
    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}

static bool cp_async_available(const int cc) {
    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}

static bool blackwell_mma_available(const int cc) {
    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL &&
           ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN;
}

static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
    return 64;
#else
    return 32;
#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
}

// Maximum number of bytes that can be copied in a single instruction.
static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() {
#ifdef GGML_USE_HIP
    return 16;
#else
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
    return 16;
#else
    return 8;
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // GGML_USE_HIP
}


[[noreturn]]
static __device__ void no_device_code(
    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {

#if defined(GGML_USE_HIP)
    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
           file_name, line, function_name, arch);
    GGML_UNUSED(arch_list);
#else
    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
           file_name, line, function_name, arch, arch_list);
#endif // defined(GGML_USE_HIP)
    __trap();

    GGML_UNUSED(no_device_code); // suppress unused function warning

#if defined(GGML_USE_MUSA)
    __builtin_unreachable();
#endif // defined(GGML_USE_MUSA)
}

#ifdef __CUDA_ARCH__
#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
#else
#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
#endif // __CUDA_ARCH__

// The compiler is not always able to unroll loops if they contain continue expressions.
// In such cases loop unrolling can still be achieved via recursion:
template <int n>
struct ggml_cuda_unroll {
    template <typename Func, typename... Args>
    __device__ void operator()(const Func & f, Args... args) const {
        f(n - 1, args...);
        ggml_cuda_unroll<n - 1>{}(f, args...);
    }
};

template <>
struct ggml_cuda_unroll<1> {
    template <typename Func, typename... Args>
    __device__ void operator()(const Func & f, Args... args) const {
        f(0, args...);
    }
};
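
// Illustrative usage (hypothetical device code, not part of this header): the functor is
// invoked once per step with a compile-time-known index, here i = 3, 2, 1, 0.
//
//   float acc = 0.0f;
//   ggml_cuda_unroll<4>{}([&](const int i) {
//       acc += x[i]; // x: some device-local array, assumed for the example
//   });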

template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_sum(int x) {
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    return __reduce_add_sync(0xffffffff, x);
#else
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, width);
    }
    return x;
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, width);
    }
    return x;
}
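
// Note (added): the loop above is a butterfly (XOR) reduction. With width == 32 each lane
// exchanges values with the lane whose id differs in one bit, at offsets 16, 8, 4, 2, 1;
// after log2(width) steps every lane holds the full sum, so no final broadcast is needed.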

template<int width = WARP_SIZE>
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
    }
    return a;
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#ifdef FP16_AVAILABLE
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
    }
    return a;

#else
    NO_DEVICE_CODE;
    return a;
#endif // FP16_AVAILABLE
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_all(int x) {
    if (width == ggml_cuda_get_physical_warp_size()) {
        return __all_sync(0xffffffff, x);
    } else {
#pragma unroll
        for (int offset = width/2; offset > 0; offset >>= 1) {
            x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
        }
        return x;
    }
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_any(int x) {
    if (width == ggml_cuda_get_physical_warp_size()) {
        return __any_sync(0xffffffff, x);
    } else {
#pragma unroll
        for (int offset = width/2; offset > 0; offset >>= 1) {
            x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
        }
        return x;
    }
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
    }
    return x;
}

template<typename T, int width = WARP_SIZE>
static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
    const int lane_id = threadIdx.x % width;
#pragma unroll
    for (int offset = 1; offset < width; offset <<= 1) {
        const T t = __shfl_up_sync(0xffffffff, x, offset, width);
        if (lane_id >= offset) {
            x += t;
        }
    }
    return x;
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
    const int lane_id = threadIdx.x % width;
#pragma unroll
    for (int offset = 1; offset < width; offset <<= 1) {
        const float t_x = __shfl_up_sync(0xffffffff, a.x, offset, width);
        const float t_y = __shfl_up_sync(0xffffffff, a.y, offset, width);
        if (lane_id >= offset) {
            a.x += t_x;
            a.y += t_y;
        }
    }
    return a;
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
#ifdef FP16_AVAILABLE
    const int lane_id = threadIdx.x % width;
#pragma unroll
    for (int offset = 1; offset < width; offset <<= 1) {
        const half2 t = __shfl_up_sync(0xffffffff, a, offset, width);
        if (lane_id >= offset) {
            a = __hadd2(a, t);
        }
    }
    return a;

#else
    NO_DEVICE_CODE;
    return a;
#endif // FP16_AVAILABLE
}

enum class block_reduce_method {
    MAX,
    SUM,
};

template<block_reduce_method method_t, typename T>
struct block_reduce_policy;

template <typename T, typename... Ts>
inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);

template<typename...>
inline constexpr bool ggml_cuda_dependent_false_v = false;

template <typename T> struct block_reduce_policy<block_reduce_method::SUM, T> {
    static __device__ T reduce(T val) {
        if constexpr(is_any<T, float, float2, half2, int>) {
            return warp_reduce_sum(val);
        } else {
            static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
        }
    }

    static __device__ T sentinel() {
        if constexpr (std::is_same_v<T, float>) {
            return 0.0f;
        } else if constexpr (std::is_same_v<T, float2>) {
            return make_float2(0.0f, 0.0f);
        } else if constexpr (std::is_same_v<T, half2>) {
            return make_half2(0.0f, 0.0f);
        } else if constexpr (std::is_same_v<T, int>) {
            return 0;
        } else {
            static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
        }
    }
};

template <typename T> struct block_reduce_policy<block_reduce_method::MAX, T> {
    static __device__ T reduce(T val) {
        if constexpr (is_any<T, float, half2>) {
            return warp_reduce_max(val);
        } else {
            static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
        }
    }

    static __device__ T sentinel() {
        if constexpr (std::is_same_v<T, float>) {
            return -INFINITY;
        } else if constexpr (std::is_same_v<T, half2>) {
            return make_half2(-INFINITY, -INFINITY);
        } else {
            static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
        }
    }
};

template <block_reduce_method reduce_method_t, const unsigned int block_size_template = 0, typename T>
static __device__ T block_reduce(T val, T * shared_vals) {
    val = block_reduce_policy<reduce_method_t, T>::reduce(val);
    const unsigned int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
    if (block_size > WARP_SIZE) {
        assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0);
        const int warp_id = threadIdx.x / WARP_SIZE;
        const int lane_id = threadIdx.x % WARP_SIZE;
        if (lane_id == 0) {
            shared_vals[warp_id] = val;
        }
        __syncthreads();
        val = block_reduce_policy<reduce_method_t, T>::sentinel();
        if (lane_id < (static_cast<int>(block_size) / WARP_SIZE)) {
            val = shared_vals[lane_id];
        }
        return block_reduce_policy<reduce_method_t, T>::reduce(val);
    }

    return val;
}
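
// Illustrative call site (hypothetical kernel, not part of this header): the caller must
// provide one shared slot per warp, i.e. at least blockDim.x / WARP_SIZE elements.
//
//   __shared__ float shared_vals[1024 / WARP_SIZE];
//   const float block_max = block_reduce<block_reduce_method::MAX>(thread_val, shared_vals);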

static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#ifdef FP16_AVAILABLE

#if !defined(GGML_USE_HIP) && CUDART_VERSION < CUDART_HMAX
    return __float2half(fmaxf(__half2float(a), __half2float(b)));
#else
    return __hmax(a, b);
#endif // !defined(GGML_USE_HIP) && CUDART_VERSION < CUDART_HMAX

#else
    NO_DEVICE_CODE;
    GGML_UNUSED(b);
    return a;
#endif // FP16_AVAILABLE
}

static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
#if defined(GGML_USE_HIP)
    return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));
#elif CUDART_VERSION >= CUDART_HMAX
    return __hmax2(a, b);
#else
    half2 ret;
    reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
    reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
    return ret;
#endif
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
    }
    return x;
#else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)
}

#if (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || \
    (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK)
static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
    const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
    const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
    return mask_low | mask_high;
}
#endif // (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK)

static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
#if defined(GGML_USE_HIP)
#if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
    c = __builtin_amdgcn_sdot4(a, b, c, false);
#elif defined(RDNA3) || defined(RDNA4)
    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
#elif defined(RDNA1) || defined(__gfx900__)
    int tmp1;
    int tmp2;
    asm("\n \
        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
        v_add3_u32 %0, %1, %2, %0 \n \
        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
        v_add3_u32 %0, %1, %2, %0 \n \
        "
        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
        : "v"(a), "v"(b)
    );
#else
    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
#endif
    return c;

#else // defined(GGML_USE_HIP)

#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
    return __dp4a(a, b, c);
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
    const int8_t * a8 = (const int8_t *) &a;
    const int8_t * b8 = (const int8_t *) &b;
    return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)

#endif // defined(GGML_USE_HIP)
}

static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) {
    acc += v*u;
}

static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) {
    acc += v.x*u.x;
    acc += v.y*u.y;
}

#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
#define V_DOT2_F32_F16_AVAILABLE
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))

static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
#ifdef V_DOT2_F32_F16_AVAILABLE
    asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
#else
#ifdef FAST_FP16_AVAILABLE
    const float2 tmp = __half22float2(v*u);
    acc += tmp.x + tmp.y;
#else
    const float2 tmpv = __half22float2(v);
    const float2 tmpu = __half22float2(u);
    acc += tmpv.x * tmpu.x;
    acc += tmpv.y * tmpu.y;
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}

static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
#ifdef FAST_FP16_AVAILABLE
    acc += v*u;
#else
    const float2 tmpv = __half22float2(v);
    const float2 tmpu = __half22float2(u);
    float2 tmpacc = __half22float2(acc);
    tmpacc.x += tmpv.x * tmpu.x;
    tmpacc.y += tmpv.y * tmpu.y;
    acc = make_half2(tmpacc.x, tmpacc.y);
#endif // FAST_FP16_AVAILABLE
}

// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
// Important: do not use this function if dst and src both point at registers.
// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types.
// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions.
// If dst and src point at different address spaces then they are guaranteed to not be aliased.
template <int nbytes, int alignment = 0>
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
    static_assert(
        nbytes <= ggml_cuda_get_max_cpy_bytes() || alignment == 0,
        "You are misusing the alignment parameter for ggml_cuda_memcpy_1. "
        "The intent is for the parameter to be used only as a workaround if either one of the pointers is not properly aligned. "
        "If you use it to do more bytes per copy than ggml_cuda_get_max_cpy_bytes() the reads and writes may not be coalesced. "
        "Call ggml_cuda_memcpy_1 in a loop instead.");
    if constexpr (alignment != 0) {
        static_assert(nbytes % alignment == 0, "bad alignment");
    }
    constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment;

#pragma unroll
    for (int i = 0; i < nbytes/nb_per_cpy; ++i) {
        if constexpr (nb_per_cpy == 1) {
            ((char *) dst)[i] = ((const char *) src)[i];
        } else if constexpr (nb_per_cpy == 2) {
            ((short *) dst)[i] = ((const short *) src)[i];
        } else if constexpr (nb_per_cpy == 4) {
            ((int *) dst)[i] = ((const int *) src)[i];
        } else if constexpr (nb_per_cpy == 8) {
            ((int2 *) dst)[i] = ((const int2 *) src)[i];
        } else if constexpr (nb_per_cpy == 16) {
            ((int4 *) dst)[i] = ((const int4 *) src)[i];
        } else {
            // dependent always-false expression: rejects unsupported copy sizes at instantiation time
            static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
        }
    }
}
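
// Illustrative usage (hypothetical, not part of this header): copy a 16-byte tile from
// shared memory into registers as a single int4 transfer.
//
//   half2 reg[4];
//   ggml_cuda_memcpy_1<sizeof(reg)>(reg, &shared_tile[threadIdx.x]); // shared_tile assumed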

static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#if CUDART_VERSION >= 12080
    const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
    return (float) e;
#else
    uint32_t bits;
    if (x == 0) {
        bits = 0x00400000;
    } else {
        bits = (uint32_t) x << 23;
    }

    float result;
    memcpy(&result, &bits, sizeof(float));
    return result;
#endif // CUDART_VERSION >= 12080
}
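
// Note (added): E8M0 stores only a biased exponent, so the decoded value is 2^(x - 127).
// Shifting x into the float exponent field (x << 23) yields exactly that power of two,
// e.g. x = 127 decodes to 1.0f and x = 128 to 2.0f. The x == 0 case is special-cased to
// the subnormal bit pattern 0x00400000, which also equals 2^-127.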

__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
    const uint8_t sign_bit = (x < 0.0f) << 3;
    float ax = fabsf(x) * e;

    // Positive LUT
    static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };

    int best_i = 0;
    float best_err = fabsf(ax - pos_lut[0]);

#pragma unroll
    for (int i = 1; i < 8; ++i) {
        const float err = fabsf(ax - pos_lut[i]);
        if (err < best_err) {
            best_err = err;
            best_i = i;
        }
    }

    return static_cast<uint8_t>(best_i | sign_bit);
}
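
// Worked example (added, values illustrative): encoding x = -1.2f with scale e = 1.0f
// gives ax = 1.2, whose nearest LUT entry is 1.0 at index 2; the sign sets bit 3, so the
// returned e2m1 nibble is 0b1010. The LUT lists all 8 non-negative e2m1 values.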

// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
// and a shift:
//
// n/d = (mulhi(n, mp) + n) >> L;
static const uint3 init_fastdiv_values(uint64_t d_64) {
    GGML_ASSERT(d_64 != 0);
    GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());

    uint32_t d = (uint32_t)d_64;

    // compute L = ceil(log2(d));
    uint32_t L = 0;
    while (L < 32 && (uint32_t{ 1 } << L) < d) {
        L++;
    }

    uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
    // pack divisor as well to reduce error surface
    return make_uint3(mp, L, d);
}

static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fastdiv_values) {
    // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z>
    // fastdiv_values.z is unused and optimized away by the compiler.
    // Compute high 32 bits of n * mp
    const uint32_t hi = __umulhi(n, fastdiv_values.x);
    // add n, apply bit shift
    return (hi + n) >> fastdiv_values.y;
}

static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fastdiv_values) {
    // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
    return n - fastdiv(n, fastdiv_values) * fastdiv_values.z;
}

// Calculate both division and modulo at once, returns <n/divisor, n%divisor>
static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 fastdiv_values) {
    // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
    const uint32_t div_val = fastdiv(n, fastdiv_values);
    const uint32_t mod_val = n - div_val * fastdiv_values.z;
    return make_uint2(div_val, mod_val);
}
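
// Illustrative usage (hypothetical kernel, not part of this header): the divisor must be
// known on the host so mp and L can be precomputed once and passed to the kernel by value.
//
//   const uint3 ne0_fd = init_fastdiv_values(ne0);        // host, once per launch
//   ...
//   const uint2 qr = fast_div_modulo(idx, ne0_fd);        // device: <idx/ne0, idx%ne0>
//   const uint32_t row = qr.x, col = qr.y;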

typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);

static __device__ __forceinline__ float get_alibi_slope(
    const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
) {
    if (max_bias <= 0.0f) {
        return 1.0f;
    }
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

    return powf(base, exph);
}
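// Worked example (illustrative; assumes the ggml convention that callers pass
// m0 = 2^(-max_bias/n_head_log2) and m1 = 2^(-max_bias/2/n_head_log2)):
// with max_bias = 8.0f and n_head_log2 = 4, m0 = 0.25f and m1 = 0.5f, so
//   h = 0 -> 0.25^1 = 0.25    (h <  n_head_log2)
//   h = 1 -> 0.25^2 = 0.0625
//   h = 4 -> 0.5^1  = 0.5     (h >= n_head_log2, exph = 2*(4-4)+1 = 1)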

template <ggml_type type>
struct ggml_cuda_type_traits;

template<>
struct ggml_cuda_type_traits<GGML_TYPE_F16> {
    static constexpr int qk = 1;
    static constexpr int qr = 1;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
    static constexpr int qk = QK4_0;
    static constexpr int qr = QR4_0;
    static constexpr int qi = QI4_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
    static constexpr int qk = QK4_1;
    static constexpr int qr = QR4_1;
    static constexpr int qi = QI4_1;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
    static constexpr int qk = QK5_0;
    static constexpr int qr = QR5_0;
    static constexpr int qi = QI5_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
    static constexpr int qk = QK5_1;
    static constexpr int qr = QR5_1;
    static constexpr int qi = QI5_1;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
    static constexpr int qk = QK8_0;
    static constexpr int qr = QR8_0;
    static constexpr int qi = QI8_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
    static constexpr int qk = QK_MXFP4;
    static constexpr int qr = QR_MXFP4;
    static constexpr int qi = QI_MXFP4;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_K;
    static constexpr int qi = QI2_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR3_K;
    static constexpr int qi = QI3_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR4_K;
    static constexpr int qi = QI4_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR5_K;
    static constexpr int qi = QI5_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR6_K;
    static constexpr int qi = QI6_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_XXS;
    static constexpr int qi = QI2_XXS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_XS;
    static constexpr int qi = QI2_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_S;
    static constexpr int qi = QI2_S;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR3_XXS;
    static constexpr int qi = QI3_XXS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR1_S;
    static constexpr int qi = QI1_S;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR1_M;
    static constexpr int qi = QI1_M;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
    static constexpr int qk = QK4_NL;
    static constexpr int qr = QR4_NL;
    static constexpr int qi = QI4_NL;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR4_XS;
    static constexpr int qi = QI4_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR3_S;
    static constexpr int qi = QI3_S;
};

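// Usage sketch (illustrative): the traits expose per-type block geometry as
// compile-time constants (qk: values per quant block, qi: block size in 32-bit
// integers, qr: dequantization ratio), e.g.
//   template <ggml_type type>
//   static __device__ void example() {
//       constexpr int qk = ggml_cuda_type_traits<type>::qk;
//       // ... index math based on qk/qr/qi ...
//   }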
//////////////////////

struct ggml_cuda_device_info {
    int device_count;

    struct cuda_device_info {
        int    cc;               // compute capability
        int    nsm;              // number of streaming multiprocessors
        size_t smpb;             // max. shared memory per block
        size_t smpbo;            // max. shared memory per block (with opt-in)
        bool   integrated;       // device is integrated as opposed to discrete
        bool   vmm;              // virtual memory support
        size_t vmm_granularity;  // granularity of virtual memory
        size_t total_vram;
        int    warp_size;        // number of threads in a dispatch
        bool   supports_cooperative_launch; // whether cooperative launch is supported
    };

    cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};

    std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
};

const ggml_cuda_device_info & ggml_cuda_info();

void ggml_cuda_set_device(int device);
int  ggml_cuda_get_device();

struct ggml_cuda_pool {
    virtual ~ggml_cuda_pool() = default;

    virtual void * alloc(size_t size, size_t * actual_size) = 0;
    virtual void   free(void * ptr, size_t size) = 0;
};

template<typename T>
struct ggml_cuda_pool_alloc {
    ggml_cuda_pool * pool = nullptr;
    T * ptr = nullptr;
    size_t actual_size = 0;

    ggml_cuda_pool_alloc() = default;

    explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) {
    }

    ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : pool(&pool) {
        alloc(size);
    }

    ~ggml_cuda_pool_alloc() {
        if (ptr != nullptr) {
            pool->free(ptr, actual_size);
        }
    }

    // size is in number of elements
    T * alloc(size_t size) {
        GGML_ASSERT(pool != nullptr);
        GGML_ASSERT(ptr == nullptr);
        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
        return ptr;
    }

    T * alloc(ggml_cuda_pool & pool, size_t size) {
        this->pool = &pool;
        return alloc(size);
    }

    T * get() {
        return ptr;
    }

    ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete;
    ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete;
    ggml_cuda_pool_alloc & operator=(const ggml_cuda_pool_alloc &) = delete;
    ggml_cuda_pool_alloc & operator=(ggml_cuda_pool_alloc &&) = delete;
};
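// Usage sketch (illustrative): RAII scratch memory drawn from a pool and returned
// automatically when the wrapper goes out of scope:
//   ggml_cuda_pool_alloc<float> tmp(ctx.pool(), n_elements); // element count, not bytes
//   float * tmp_d = tmp.get();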


// backend interface

struct ggml_tensor_extra_gpu {
    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
    cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
};


#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)) || defined(GGML_MUSA_GRAPHS)
#define USE_CUDA_GRAPH
#endif

struct ggml_cuda_graph_node_properties {
    void * node_data;
    ggml_op node_op;
    enum ggml_type node_type;
    int32_t flags;
    int64_t ne[GGML_MAX_DIMS];
    size_t nb[GGML_MAX_DIMS];
    void * src_data[GGML_MAX_SRC];
    int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
};

static_assert(std::is_trivial<ggml_cuda_graph_node_properties>::value, "ggml_cuda_graph_node_properties must be trivial");

struct ggml_cuda_graph {
#ifdef USE_CUDA_GRAPH
    ~ggml_cuda_graph() {
        if (instance != nullptr) {
            CUDA_CHECK(cudaGraphExecDestroy(instance));
        }
        if (graph != nullptr) {
            CUDA_CHECK(cudaGraphDestroy(graph));
        }
    }
    cudaGraph_t graph = nullptr;
    cudaGraphExec_t instance = nullptr;
    size_t num_nodes = 0;
    std::vector<cudaGraphNode_t> nodes;
    bool disable_due_to_gpu_arch = false;
    bool disable_due_to_too_many_updates = false;
    int number_consecutive_updates = 0;
    std::vector<ggml_cuda_graph_node_properties> props;

    // these are extra tensors (inputs) that participate in the ggml graph but are not nodes
    // their properties also have to match in order to safely reuse a CUDA graph
    // ref: https://github.com/ggml-org/llama.cpp/pull/18583
    // ref: https://github.com/ggml-org/llama.cpp/pull/19165
    std::vector<ggml_cuda_graph_node_properties> extra;

    void record_update(bool use_graph, bool update_required) {
        if (use_graph && update_required) {
            number_consecutive_updates++;
        } else {
            number_consecutive_updates = 0;
        }
        if (number_consecutive_updates >= 4) {
            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
            disable_due_to_too_many_updates = true;
        }
    }

    bool is_enabled() const {
        static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
        return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates);
    }
#endif
};
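// Lifecycle sketch (illustrative; needs_update() is a hypothetical stand-in for the
// backend's comparison of the current ggml graph against props/extra):
//   bool use_graph       = g.is_enabled();
//   bool update_required = needs_update(cgraph, g);
//   g.record_update(use_graph, update_required);
//   // 4+ consecutive updates set disable_due_to_too_many_updates, turning graphs off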

struct ggml_cuda_concurrent_event {
    std::vector<cudaEvent_t> join_events;
    cudaEvent_t fork_event = nullptr;

    int n_streams = 0;
    std::unordered_map<const ggml_tensor *, int> stream_mapping;

    // original order of nodes in this concurrent region (before interleaving),
    // used to restore grouping for fusion within streams
    std::vector<const ggml_tensor *> original_order;

    const ggml_tensor * join_node;

    ggml_cuda_concurrent_event() = default;

    ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete;
    ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete;

    explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
        join_events.resize(n_streams);

        for (size_t i = 0; i < join_events.size(); ++i) {
            CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
        }

        CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
    }

    ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept
        : join_events(std::move(other.join_events))
        , fork_event(other.fork_event)
        , n_streams(other.n_streams)
        , stream_mapping(std::move(other.stream_mapping))
        , original_order(std::move(other.original_order))
        , join_node(other.join_node) {
        other.fork_event = nullptr;
    }

    // 1. check whether any branches write to overlapping memory ranges (except the join node)
    // 2. check whether all srcs are either within the same branch or outside the nodes covered by ggml_cuda_concurrent_event
    // we assume all nodes have the same buffer
    bool is_valid() const {
        std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges;
        write_ranges.resize(n_streams);

        // get join_node's memory range to exclude from overlap checking;
        // multiple nodes can use join_node's buffer, we synchronize on the join node
        const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node;
        const int64_t join_start = (int64_t) join_t->data;
        const int64_t join_end   = join_start + ggml_nbytes(join_t);

        for (const auto & [tensor, stream] : stream_mapping) {
            const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
            const int64_t t_start = (int64_t) t->data;
            const int64_t t_end   = t_start + ggml_nbytes(t);

            // skip tensors that overlap with join_node's buffer
            if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
                continue;
            }

            // concurrent streams begin from 1
            write_ranges[stream - 1].emplace_back(t_start, t_end);
        }

        for (int i = 0; i < n_streams; ++i) {
            // sorts first by start, then by end of write range
            std::sort(write_ranges[i].begin(), write_ranges[i].end());
        }

        bool writes_overlap = false;
        bool dependent_srcs = false;
        for (const auto & [tensor, stream] : stream_mapping) {
            const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
            const int64_t t_start = (int64_t) t->data;
            const int64_t t_end   = t_start + ggml_nbytes(t);

            // skip tensors that overlap with join_node's buffer
            if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
                continue;
            }

            // check if this tensor's write range overlaps with another stream's
            std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end);
            for (int i = 0; i < n_streams; ++i) {
                if (i == stream - 1) {
                    continue;
                }
                auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range);

                if (it != write_ranges[i].end()) {
                    const std::pair<int64_t, int64_t> & other = *it;

                    // std::lower_bound returns the first element where other >= data_range (lexicographically),
                    // which guarantees other.first >= data_range.first.
                    // Therefore, overlap occurs iff other.first < data_range.second
                    // (i.e. the other range starts before this range ends).
                    if (other.first < data_range.second) {
                        GGML_LOG_DEBUG("Writes overlap for %s", tensor->name);
                        writes_overlap = true;
                        break;
                    }
                }
            }

            // check that all srcs are either mapped to the same stream or outside the concurrent region
            for (int i = 0; i < GGML_MAX_SRC; ++i) {
                if (!tensor->src[i]) {
                    continue;
                }

                auto it = stream_mapping.find(tensor->src[i]);

                if (it == stream_mapping.end()) {
                    continue;
                }

                if (it->second != stream) {
                    dependent_srcs = true;
                    break;
                }
            }

            if (dependent_srcs || writes_overlap) {
                break;
            }
        }

        return !writes_overlap && !dependent_srcs;
    }
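    // Worked example (illustrative): if stream 1 writes [0, 100) and stream 2 writes
    // [50, 150), checking stream 1's range (0, 100) against stream 2's sorted list
    // finds (50, 150) via lower_bound, and 50 < 100 flags the overlap. Since every
    // tensor is checked against every other stream, a pair missed from one side is
    // still caught when scanning from the lexicographically smaller range.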

    ~ggml_cuda_concurrent_event() {
        if (fork_event != nullptr) {
            CUDA_CHECK(cudaEventDestroy(fork_event));
        }
        for (cudaEvent_t e : join_events) {
            if (e != nullptr) {
                CUDA_CHECK(cudaEventDestroy(e));
            }
        }
    }
};

struct ggml_cuda_stream_context {
    std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;

    void reset() {
        concurrent_events.clear();
    }
};

struct ggml_backend_cuda_context {
    int device;
    std::string name;
    cudaEvent_t copy_event = nullptr;

    cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
    cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

    int curr_stream_no = 0;

#ifdef USE_CUDA_GRAPH
    // Map from first_node_ptr to cuda_graph - allows multiple graphs per context
    // when the computation is split across CPU/GPU (e.g. with --n-cpu-moe)
    std::unordered_map<const void *, std::unique_ptr<ggml_cuda_graph>> cuda_graphs;

    ggml_cuda_graph * cuda_graph(const void * first_node_ptr) {
        auto it = cuda_graphs.find(first_node_ptr);
        if (it == cuda_graphs.end()) {
            cuda_graphs[first_node_ptr] = std::make_unique<ggml_cuda_graph>();
            return cuda_graphs[first_node_ptr].get();
        }
        return it->second.get();
    }

    // Check if any CUDA graph is enabled for this context (used by kernels that need to know
    // if graphs are in use without having access to the specific graph key)
    bool any_cuda_graph_enabled() const {
        for (const auto & [key, graph] : cuda_graphs) {
            if (graph && graph->is_enabled()) {
                return true;
            }
        }
        return false;
    }

    // Check if any CUDA graph has an instance for this context
    bool any_cuda_graph_has_instance() const {
        for (const auto & [key, graph] : cuda_graphs) {
            if (graph && graph->instance != nullptr) {
                return true;
            }
        }
        return false;
    }
#endif // USE_CUDA_GRAPH

    explicit ggml_backend_cuda_context(int device) :
        device(device),
        name(GGML_CUDA_NAME + std::to_string(device)) {
    }

    ggml_cuda_stream_context concurrent_stream_context;

    ~ggml_backend_cuda_context();

    cudaStream_t stream(int device, int stream) {
        if (streams[device][stream] == nullptr) {
            ggml_cuda_set_device(device);
            CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
        }
        return streams[device][stream];
    }

    cudaStream_t stream() { return stream(device, curr_stream_no); }

    ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }

    cublasHandle_t cublas_handle(int device) {
        if (cublas_handles[device] == nullptr) {
            ggml_cuda_set_device(device);
            CUBLAS_CHECK(cublasCreate(&cublas_handles[device]));
            CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH));
        }
        return cublas_handles[device];
    }

    cublasHandle_t cublas_handle() {
        return cublas_handle(device);
    }

    // pool
    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];

    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);

    ggml_cuda_pool & pool(int device) {
        if (pools[device][curr_stream_no] == nullptr) {
            pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
        }
        return *pools[device][curr_stream_no];
    }

    ggml_cuda_pool & pool() {
        return pool(device);
    }
};
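// Usage sketch (illustrative): everything in the context is created lazily on first use.
//   ggml_backend_cuda_context ctx(0);
//   cudaStream_t   s = ctx.stream();        // non-blocking stream, created on demand
//   cublasHandle_t h = ctx.cublas_handle(); // cuBLAS handle with TF32 tensor-op math
//   ggml_cuda_pool & p = ctx.pool();        // per-device, per-stream memory pool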

struct ggml_cuda_mm_fusion_args_host {
    const ggml_tensor * x_bias = nullptr;
    const ggml_tensor * gate = nullptr;
    const ggml_tensor * gate_bias = nullptr;
    ggml_glu_op glu_op;
};

struct ggml_cuda_mm_fusion_args_device {
    const void * x_bias = nullptr;
    const void * gate = nullptr;
    const void * gate_bias = nullptr;
    ggml_glu_op glu_op;
};