path: root/llama.cpp/ggml/src/ggml-cuda
author    Mitja Felicijan <mitja.felicijan@gmail.com>  2026-02-12 20:57:17 +0100
committer Mitja Felicijan <mitja.felicijan@gmail.com>  2026-02-12 20:57:17 +0100
commit    b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree      211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-cuda
download  llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-cuda')
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt259
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/acc.cu61
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/acc.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/add-id.cu58
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/add-id.cuh3
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/arange.cu34
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/arange.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/argmax.cu91
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/argmax.cuh3
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/argsort.cu230
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/argsort.cuh19
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/binbcast.cu504
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/binbcast.cuh11
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/clamp.cu45
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/clamp.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/common.cuh1440
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/concat.cu221
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/concat.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu86
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu161
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu91
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh4
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/conv2d.cu166
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/conv2d.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/convert.cu825
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/convert.cuh56
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/count-equal.cu64
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/count-equal.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/cp-async.cuh57
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh217
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/cpy.cu555
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/cpy.cuh7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu177
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cuh7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/cumsum.cu307
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/cumsum.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/dequantize.cuh77
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/diag.cu77
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/diag.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/diagmask.cu40
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/diagmask.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh1036
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh1750
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/fattn-tile.cu49
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/fattn-tile.cuh1256
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh586
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu675
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cuh51
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/fattn.cu482
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/fattn.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/fill.cu37
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/fill.cuh3
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/getrows.cu286
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/getrows.cuh15
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu5118
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/gla.cu93
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/gla.cuh3
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/im2col.cu264
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/im2col.cuh6
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/mean.cu75
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/mean.cuh3
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/mma.cuh1381
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/mmf.cu191
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/mmf.cuh908
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/mmid.cu164
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/mmid.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/mmq.cu366
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/mmq.cuh4092
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/mmvf.cu862
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/mmvf.cuh14
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/mmvq.cu767
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/mmvq.cuh12
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/norm.cu672
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/norm.cuh18
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/opt-step-adamw.cu78
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/opt-step-adamw.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu49
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/out-prod.cu68
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/out-prod.cuh3
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/pad.cu106
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/pad.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu91
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/pool2d.cu94
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/pool2d.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/quantize.cu343
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/quantize.cuh41
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh39
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/roll.cu67
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/roll.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/rope.cu665
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/rope.cuh9
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/scale.cu34
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/scale.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/set-rows.cu330
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/set-rows.cuh7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/set.cu39
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/set.cuh7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/softcap.cu34
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/softcap.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/softmax.cu472
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/softmax.cuh7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/solve_tri.cu275
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/solve_tri.cuh3
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu150
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/ssm-conv.cuh3
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu342
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/ssm-scan.cuh3
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/sum.cu41
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/sum.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/sumrows.cu43
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/sumrows.cuh4
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu10
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu10
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu10
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu11
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu11
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu10
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu10
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu10
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu10
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu11
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu10
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu10
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu10
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu10
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu11
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu10
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu7
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu7
-rwxr-xr-xllama.cpp/ggml/src/ggml-cuda/template-instances/generate_cu_files.py99
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/top-k.cu95
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/top-k.cuh3
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/topk-moe.cu403
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh27
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/tri.cu136
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/tri.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/tsembd.cu47
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/tsembd.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/unary.cu562
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/unary.cuh110
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/upscale.cu293
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/upscale.cuh5
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh1223
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h23
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/vendors/hip.h278
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/vendors/musa.h147
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/wkv.cu199
-rw-r--r--llama.cpp/ggml/src/ggml-cuda/wkv.cuh7
235 files changed, 35073 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt b/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt
new file mode 100644
index 0000000..262f882
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt
@@ -0,0 +1,259 @@
+cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
+
+find_package(CUDAToolkit)
+
+if (CUDAToolkit_FOUND)
+ message(STATUS "CUDA Toolkit found")
+
+ if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+ # native == GPUs available at build time
+ # 50 == Maxwell, lowest CUDA 12 standard
+ # 60 == P100, FP16 CUDA intrinsics
+ # 61 == Pascal, __dp4a instruction (per-byte integer dot product)
+ # 70 == V100, FP16 tensor cores
+ # 75 == Turing, int8 tensor cores
+ # 80 == Ampere, asynchronous data loading, faster tensor core instructions
+ # 86 == RTX 3000, needs CUDA v11.1
+ # 89 == RTX 4000, needs CUDA v11.8
+ # 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores
+ #
+ # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
+ # XX-real == compile CUDA code as device code for this specific architecture
+ # no suffix == compile as both PTX and device code
+ #
+ # The default behavior for a non-native build is to build virtual architectures as needed to cover all features needed
+ # for best performance and to also build real architectures for the most commonly used GPUs.
+ if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24")
+ set(CMAKE_CUDA_ARCHITECTURES "native")
+ else()
+ if (CUDAToolkit_VERSION VERSION_LESS "13")
+ list(APPEND CMAKE_CUDA_ARCHITECTURES 50-virtual 61-virtual 70-virtual)
+ endif ()
+
+ list(APPEND CMAKE_CUDA_ARCHITECTURES 75-virtual 80-virtual 86-real)
+
+ if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
+ list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real)
+ endif()
+
+ if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
+ # The CUDA architecture 120f-virtual would in principle work for Blackwell support
+ # but the newly added "f" suffix conflicted with a preexisting regex for validating CUDA architectures in CMake.
+ # So either a recent CMake version or one with the backported fix is needed.
+ # The following versions should work:
+ # - CMake >= v3.31.8 && CMake < v4.0.0
+ # - CMake >= v4.0.2
+ # This is NOT documented in the CMake release notes,
+ # check Modules/Internal/CMakeCUDAArchitecturesValidate.cmake in the CMake git repository instead.
+ # However, the architectures 120a-real and 121a-real should work with basically any CMake version and
+ # until the release of e.g. Rubin there is no benefit to shipping virtual architectures for Blackwell.
+ list(APPEND CMAKE_CUDA_ARCHITECTURES 120a-real)
+ endif()
+ if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.9")
+ list(APPEND CMAKE_CUDA_ARCHITECTURES 121a-real)
+ endif()
+ endif()
+ endif()
+
+ enable_language(CUDA)
+
+ # TODO: Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit
+ if (GGML_CUDA_CUB_3DOT2)
+ include(FetchContent)
+
+ FetchContent_Declare(
+ CCCL
+ GIT_REPOSITORY https://github.com/nvidia/cccl.git
+ GIT_TAG v3.2.0
+ GIT_SHALLOW TRUE
+ )
+
+ FetchContent_MakeAvailable(CCCL)
+ endif()
+
+ # Replace any plain 12X CUDA architectures with their "architecture-specific" equivalents 12Xa.
+ # 12X is forwards-compatible, 12Xa is not.
+ # Notably the Blackwell FP4 tensor core instructions are not forwards compatible and therefore need 12Xa.
+ # But while 12X vs. 12Xa can be checked in device code there is (to my knowledge) no easy way to do the same check in host code.
+ # So for now just replace all instances of 12X with 12Xa, this should be fine until Rubin is released.
+ foreach(ARCHS IN ITEMS CMAKE_CUDA_ARCHITECTURES CMAKE_CUDA_ARCHITECTURES_NATIVE)
+ set(FIXED_ARCHS "")
+ foreach(ARCH IN LISTS ${ARCHS})
+ if (ARCH MATCHES "^12[0-9](-real|-virtual)?$")
+ string(REGEX REPLACE "^(12[0-9])((-real|-virtual)?)$" "\\1a\\2" FIXED_ARCH ${ARCH})
+ message(STATUS "Replacing ${ARCH} in ${ARCHS} with ${FIXED_ARCH}")
+ list(APPEND FIXED_ARCHS "${FIXED_ARCH}")
+ else()
+ list(APPEND FIXED_ARCHS "${ARCH}")
+ endif()
+ endforeach()
+ set(${ARCHS} ${FIXED_ARCHS})
+ endforeach()
+
+ # If we try to compile a "native" build it will use the 12X architectures and fail.
+ # So we should instead use the native architectures as determined by CMake after replacing 12X with 12Xa.
+ # But if at the time of the build no GPUs are connected at all CMAKE_CUDA_ARCHITECTURES_NATIVE will contain garbage that we should not use.
+ if (CMAKE_CUDA_ARCHITECTURES STREQUAL "native" AND CMAKE_CUDA_ARCHITECTURES_NATIVE MATCHES "^[0-9]+(a|f)?(-real|-virtual)?(;[0-9]+(a|f)?(-real|-virtual)?|;)*$")
+ set(CMAKE_CUDA_ARCHITECTURES ${CMAKE_CUDA_ARCHITECTURES_NATIVE})
+ endif()
+ message(STATUS "Using CMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} CMAKE_CUDA_ARCHITECTURES_NATIVE=${CMAKE_CUDA_ARCHITECTURES_NATIVE}")
+
+ file(GLOB GGML_HEADERS_CUDA "*.cuh")
+ list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
+
+ file(GLOB GGML_SOURCES_CUDA "*.cu")
+ file(GLOB SRCS "template-instances/fattn-tile*.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
+ file(GLOB SRCS "template-instances/fattn-mma*.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
+ file(GLOB SRCS "template-instances/mmq*.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
+ file(GLOB SRCS "template-instances/mmf*.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
+
+ if (GGML_CUDA_FA_ALL_QUANTS)
+ file(GLOB SRCS "template-instances/fattn-vec*.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
+ add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
+ else()
+ file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
+ file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
+ file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
+ endif()
+
+ ggml_add_backend_library(ggml-cuda
+ ${GGML_HEADERS_CUDA}
+ ${GGML_SOURCES_CUDA}
+ )
+
+ add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
+
+ if (GGML_CUDA_GRAPHS)
+ add_compile_definitions(GGML_CUDA_USE_GRAPHS)
+ endif()
+
+ if (GGML_CUDA_FORCE_MMQ)
+ add_compile_definitions(GGML_CUDA_FORCE_MMQ)
+ endif()
+
+ if (GGML_CUDA_FORCE_CUBLAS)
+ add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
+ endif()
+
+ if (GGML_CUDA_NO_VMM)
+ add_compile_definitions(GGML_CUDA_NO_VMM)
+ endif()
+
+ if (NOT GGML_CUDA_FA)
+ add_compile_definitions(GGML_CUDA_NO_FA)
+ endif()
+
+ if (GGML_CUDA_NO_PEER_COPY)
+ add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
+ endif()
+
+ if (GGML_STATIC)
+ if (WIN32)
+ # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
+ else ()
+ if (GGML_CUDA_CUB_3DOT2)
+ target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
+ endif()
+ if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+ else()
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static)
+ endif()
+ endif()
+ else()
+ if (GGML_CUDA_CUB_3DOT2)
+ target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
+ endif()
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
+ endif()
+
+ if (GGML_CUDA_NO_VMM)
+ # No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
+ else()
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cuda_driver)
+ endif()
+
+ set(CUDA_CXX_FLAGS "")
+
+ set(CUDA_FLAGS -use_fast_math -extended-lambda)
+
+ if (GGML_CUDA_DEBUG)
+ list(APPEND CUDA_FLAGS -lineinfo)
+ add_compile_definitions(GGML_CUDA_DEBUG)
+ endif()
+
+ if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
+ # Options are:
+ # - none (not recommended)
+ # - speed (nvcc's default)
+ # - balance
+ # - size
+ list(APPEND CUDA_FLAGS -compress-mode=${GGML_CUDA_COMPRESSION_MODE})
+ endif()
+
+ if (GGML_FATAL_WARNINGS)
+ list(APPEND CUDA_FLAGS -Werror all-warnings)
+ endif()
+
+ if (GGML_ALL_WARNINGS AND NOT MSVC)
+ set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c)
+ if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL "")
+ list(APPEND NVCC_CMD -ccbin ${CMAKE_CUDA_HOST_COMPILER})
+ endif()
+
+ execute_process(
+ COMMAND ${NVCC_CMD} -Xcompiler --version
+ OUTPUT_VARIABLE CUDA_CCFULLVER
+ ERROR_QUIET
+ )
+
+ if (NOT CUDA_CCFULLVER MATCHES clang)
+ set(CUDA_CCID "GNU")
+ execute_process(
+ COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion"
+ OUTPUT_VARIABLE CUDA_CCVER
+ ERROR_QUIET
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ )
+ else()
+ if (CUDA_CCFULLVER MATCHES Apple)
+ set(CUDA_CCID "AppleClang")
+ else()
+ set(CUDA_CCID "Clang")
+ endif()
+ string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER})
+ endif()
+
+ message(STATUS "CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}")
+
+ ggml_get_flags(${CUDA_CCID} ${CUDA_CCVER})
+ list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later
+ endif()
+
+ if (NOT MSVC)
+ list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
+ else()
+ # CCCL 3.2 onwards will require a cpp-standard-compliant preprocessor for MSVC
+ # https://github.com/NVIDIA/cccl/pull/6827
+ list(APPEND CUDA_CXX_FLAGS /Zc:preprocessor)
+ endif()
+
+ list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument
+
+ if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL "")
+ list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED})
+ endif()
+
+ target_compile_options(ggml-cuda PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>")
+else()
+ message(FATAL_ERROR "CUDA Toolkit not found")
+endif()
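
The architecture table in the comments above also matters inside the kernels: device code is typically gated on the compute capability it is compiled for. A minimal sketch of that pattern, assuming a hypothetical helper (not part of this diff) and using only the standard __CUDA_ARCH__ macro and the __dp4a intrinsic that the "61 == Pascal" entry refers to:

// Illustrative only: feature gating on __CUDA_ARCH__, the reason the list above
// distinguishes e.g. 61 (adds __dp4a) from 50.
static __device__ __forceinline__ int dot4_i8(int a, int b, int acc) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 610
    return __dp4a(a, b, acc); // per-byte integer dot product, sm_61 and newer
#else
    // portable fallback for older architectures
    const char4 va = *reinterpret_cast<const char4 *>(&a);
    const char4 vb = *reinterpret_cast<const char4 *>(&b);
    return acc + va.x*vb.x + va.y*vb.y + va.z*vb.z + va.w*vb.w;
#endif
}

The -virtual entries keep PTX in the binary so GPUs newer than any listed -real architecture can still JIT-compile the kernels on first run, as described in the comment block above.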
diff --git a/llama.cpp/ggml/src/ggml-cuda/acc.cu b/llama.cpp/ggml/src/ggml-cuda/acc.cu
new file mode 100644
index 0000000..e084607
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/acc.cu
@@ -0,0 +1,61 @@
+#include "acc.cuh"
+
+static __global__ void acc_f32(const float * x, const float * y, float * dst, const int64_t ne,
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+ const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) {
+ const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
+
+ if (i >= ne) {
+ return;
+ }
+
+ int64_t src1_idx = i - offset;
+
+ int64_t tmp = src1_idx;
+ const int64_t i13 = tmp / s13;
+ tmp -= i13 * s13;
+ const int64_t i12 = tmp / s12;
+ tmp -= i12 * s12;
+ const int64_t i11 = tmp / s11;
+ tmp -= i11 * s11;
+ const int64_t i10 = tmp;
+
+ float val = x[i];
+ if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
+ val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10];
+ }
+ dst[i] = val;
+}
+
+static void acc_f32_cuda(const float * x, const float * y, float * dst, const int64_t n_elements,
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+ const int64_t s1, const int64_t s2, const int64_t s3, const int64_t offset, cudaStream_t stream) {
+ const int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
+ acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset);
+}
+
+void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ const float * src0_d = (const float *) src0->data;
+ const float * src1_d = (const float *) src1->data;
+ float * dst_d = (float *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous(src1));
+ GGML_ASSERT(dst->nb[0] == ggml_element_size(dst));
+ GGML_ASSERT(ggml_is_contiguously_allocated(dst));
+
+ const int64_t s1 = dst->op_params[0] / sizeof(float);
+ const int64_t s2 = dst->op_params[1] / sizeof(float);
+ const int64_t s3 = dst->op_params[2] / sizeof(float);
+ const int64_t offset = dst->op_params[3] / sizeof(float);
+
+ acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], s1, s2, s3, offset, stream);
+}
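
The acc kernel above flattens dst into a single index i and recovers the src1 view coordinates by repeated division by the view strides s11/s12/s13 (converted from bytes to elements in ggml_cuda_op_acc). A plain CPU sketch of the same indexing, assuming contiguous f32 buffers and a hypothetical function name, may make the decomposition easier to follow:

#include <cstdint>

// Illustrative CPU reference, not part of the diff: dst/x have n elements,
// y is a contiguous [ne10, ne11, ne12, ne13] tensor viewed into dst at `offset`
// with element strides s11/s12/s13.
static void acc_f32_ref(const float * x, const float * y, float * dst, int64_t n,
                        int64_t ne10, int64_t ne11, int64_t ne12, int64_t ne13,
                        int64_t s11, int64_t s12, int64_t s13, int64_t offset) {
    for (int64_t i = 0; i < n; ++i) {
        int64_t t = i - offset;
        const int64_t i13 = t / s13; t -= i13 * s13;
        const int64_t i12 = t / s12; t -= i12 * s12;
        const int64_t i11 = t / s11; t -= i11 * s11;
        const int64_t i10 = t;
        float v = x[i];
        // only positions that fall inside the y view get the addend
        if (i - offset >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) {
            v += y[((i13 * ne12 + i12) * ne11 + i11) * ne10 + i10];
        }
        dst[i] = v;
    }
}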
diff --git a/llama.cpp/ggml/src/ggml-cuda/acc.cuh b/llama.cpp/ggml/src/ggml-cuda/acc.cuh
new file mode 100644
index 0000000..1168ea1
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/acc.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_ACC_BLOCK_SIZE 256
+
+void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/add-id.cu b/llama.cpp/ggml/src/ggml-cuda/add-id.cu
new file mode 100644
index 0000000..8d9cf69
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/add-id.cu
@@ -0,0 +1,58 @@
+#include "add-id.cuh"
+
+static __global__ void add_id_kernel(
+ const float * src0, const float * src1, const int32_t * src2, float * dst,
+ int64_t ne0, int64_t ne1,
+ size_t nb01, size_t nb02,
+ size_t nb11,
+ size_t nb21
+ ) {
+
+ const int64_t i1 = blockIdx.x;
+ const int64_t i2 = blockIdx.y;
+
+ const int i11 = *(const int32_t *) ((const char *) src2 + i1*sizeof(int32_t) + i2*nb21);
+
+ const size_t nb1 = ne0 * sizeof(float);
+ const size_t nb2 = ne1 * nb1;
+
+ float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);
+ const float * src0_row = (const float *)((const char *)src0 + i1*nb01 + i2*nb02);
+ const float * src1_row = (const float *)((const char *)src1 + i11*nb11);
+
+ for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+ dst_row[i0] = src0_row[i0] + src1_row[i0];
+ }
+}
+
+void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+
+ GGML_TENSOR_TERNARY_OP_LOCALS
+
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+ GGML_ASSERT(nb00 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));
+ GGML_ASSERT(nb20 == sizeof(int32_t));
+
+ const float * src0_d = (const float *)src0->data;
+ const float * src1_d = (const float *)src1->data;
+ const int32_t * src2_d = (const int32_t *)src2->data;
+ float * dst_d = (float *)dst->data;
+
+ int threads = std::min((int)ne00, 768); // cols
+ dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
+ add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
+ src0_d, src1_d, src2_d, dst_d,
+ ne0, ne1,
+ nb01, nb02,
+ nb11,
+ nb21
+ );
+}
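
add_id adds an expert-selected bias row to each activation row: for every (expert slot i1, token i2) pair the kernel looks up an expert id in src2 and adds that expert's row of src1 to the matching row of src0. A rough CPU equivalent, assuming contiguous tensors and hypothetical names (the real op works through the byte strides nb01/nb02/nb11/nb21):

#include <cstdint>

// Illustrative CPU reference, not part of the diff.
// src0: [ne0, ne1, ne2] activations (ne1 = expert slots per token, ne2 = tokens)
// src1: [ne0, n_expert] per-expert bias rows
// src2: [ne1, ne2]      expert ids, one per (slot, token)
static void add_id_ref(const float * src0, const float * src1, const int32_t * src2,
                       float * dst, int64_t ne0, int64_t ne1, int64_t ne2) {
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
            const int32_t id = src2[i2 * ne1 + i1];
            const float * a = src0 + (i2 * ne1 + i1) * ne0;
            const float * b = src1 + (int64_t) id * ne0;
            float       * d = dst  + (i2 * ne1 + i1) * ne0;
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
                d[i0] = a[i0] + b[i0];
            }
        }
    }
}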
diff --git a/llama.cpp/ggml/src/ggml-cuda/add-id.cuh b/llama.cpp/ggml/src/ggml-cuda/add-id.cuh
new file mode 100644
index 0000000..30b1721
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/add-id.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/arange.cu b/llama.cpp/ggml/src/ggml-cuda/arange.cu
new file mode 100644
index 0000000..b5e495a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/arange.cu
@@ -0,0 +1,34 @@
+#include "arange.cuh"
+
+static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
+ // blockIdx.x: idx of ne0 / BLOCK_SIZE
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+ dst[nidx] = start + step * nidx;
+}
+
+static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
+ int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE;
+ arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start, step);
+}
+
+void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ float start;
+ float stop;
+ float step;
+ memcpy(&start, (float *)dst->op_params + 0, sizeof(float));
+ memcpy(&stop, (float *)dst->op_params + 1, sizeof(float));
+ memcpy(&step, (float *)dst->op_params + 2, sizeof(float));
+
+ int64_t steps = (int64_t)ceil((stop - start) / step);
+ GGML_ASSERT(ggml_nelements(dst) == steps);
+
+ arange_f32_cuda(dst_d, dst->ne[0], start, step, stream);
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/arange.cuh b/llama.cpp/ggml/src/ggml-cuda/arange.cuh
new file mode 100644
index 0000000..41e74fd
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/arange.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_ARANGE_BLOCK_SIZE 256
+
+void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/argmax.cu b/llama.cpp/ggml/src/ggml-cuda/argmax.cu
new file mode 100644
index 0000000..51967c6
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/argmax.cu
@@ -0,0 +1,91 @@
+#include <algorithm>
+#include <cstdint>
+
+#include "argmax.cuh"
+#include "common.cuh"
+#include "sum.cuh"
+
+static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
+ const int64_t row = blockIdx.x;
+
+ float maxval = -FLT_MAX;
+ int argmax = -1;
+ const float * rowx = x + row * ncols;
+
+ for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
+ const float val = rowx[col];
+ if (val > maxval) {
+ maxval = val;
+ argmax = col;
+ }
+ }
+
+#pragma unroll
+ for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
+ const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
+ const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+ if (val > maxval) {
+ maxval = val;
+ argmax = col;
+ }
+ }
+
+ const int n_warps = blockDim.x / WARP_SIZE;
+ const int lane_id = threadIdx.x % WARP_SIZE;
+ const int warp_id = threadIdx.x / WARP_SIZE;
+ if (n_warps > 1) {
+ constexpr int max_warps = 1024 / WARP_SIZE;
+ __shared__ float shared_maxval[max_warps];
+ __shared__ int shared_argmax[max_warps];
+ if (lane_id == 0) {
+ shared_maxval[warp_id] = maxval;
+ shared_argmax[warp_id] = argmax;
+ }
+
+ __syncthreads();
+
+ if (warp_id == 0) {
+ if (lane_id < n_warps) {
+ maxval = shared_maxval[lane_id];
+ argmax = shared_argmax[lane_id];
+ }
+#pragma unroll
+ for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
+ const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
+ const int col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+ if (val > maxval) {
+ maxval = val;
+ argmax = col;
+ }
+ }
+ }
+ }
+
+ if (warp_id == 0 && lane_id == 0) {
+ dst[row] = argmax;
+ }
+}
+
+void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ const float * src0_d = (const float *) src0->data;
+ int32_t * dst_d = (int32_t *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ const int64_t num_blocks = nrows;
+ const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
+ const dim3 blocks_dim(num_threads, 1, 1);
+ const dim3 blocks_num(num_blocks, 1, 1);
+
+ argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);
+}
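
The argmax reduction above runs in two stages: a butterfly reduction with __shfl_xor_sync inside each warp, then a second pass over per-warp results staged in shared memory. The warp stage in isolation looks roughly like this (a sketch with a hypothetical helper name; the kernel above inlines it and takes WARP_SIZE from common.cuh):

// Illustrative sketch, not part of the diff: warp-wide argmax via XOR shuffles.
// After log2(32) = 5 steps every lane holds the (value, index) pair of the warp maximum.
static __device__ __forceinline__ void warp_argmax(float & maxval, int & argmax) {
#pragma unroll
    for (int offset = 32/2; offset > 0; offset >>= 1) {
        const float other_val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, 32);
        const int   other_idx = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, 32);
        if (other_val > maxval) { // strict compare: ties keep the candidate already held
            maxval = other_val;
            argmax = other_idx;
        }
    }
}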
diff --git a/llama.cpp/ggml/src/ggml-cuda/argmax.cuh b/llama.cpp/ggml/src/ggml-cuda/argmax.cuh
new file mode 100644
index 0000000..5b7223a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/argmax.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/argsort.cu b/llama.cpp/ggml/src/ggml-cuda/argsort.cu
new file mode 100644
index 0000000..4896669
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/argsort.cu
@@ -0,0 +1,230 @@
+#include "argsort.cuh"
+
+#ifdef GGML_CUDA_USE_CUB
+# include <cub/cub.cuh>
+# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1)
+# define STRIDED_ITERATOR_AVAILABLE
+# endif
+using namespace cub;
+#endif // GGML_CUDA_USE_CUB
+
+static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
+ const int col = blockIdx.x * blockDim.x + threadIdx.x;
+ const int row = blockIdx.y;
+
+ if (col < ncols && row < nrows) {
+ indices[row * ncols + col] = col;
+ }
+}
+
+#ifndef STRIDED_ITERATOR_AVAILABLE
+static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx <= nrows) {
+ offsets[idx] = idx * ncols;
+ }
+}
+#endif // STRIDED_ITERATOR_AVAILABLE
+
+#ifdef GGML_CUDA_USE_CUB
+void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+ const float * x,
+ int * dst,
+ const int ncols,
+ const int nrows,
+ ggml_sort_order order,
+ cudaStream_t stream) {
+ ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
+ ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
+
+ int * temp_indices = temp_indices_alloc.get();
+ float * temp_keys = temp_keys_alloc.get();
+
+ static const int block_size = 256;
+ const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
+ init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);
+
+#ifdef STRIDED_ITERATOR_AVAILABLE
+ auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols);
+#else
+ ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
+ int * offset_iterator = offsets_alloc.get();
+ const dim3 offset_grid((nrows + block_size - 1) / block_size);
+ init_offsets<<<offset_grid, block_size, 0, stream>>>(offset_iterator, ncols, nrows);
+#endif
+ CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream));
+
+ size_t temp_storage_bytes = 0;
+
+ if (order == GGML_SORT_ORDER_ASC) {
+ if (nrows == 1) {
+ DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+ temp_indices, dst, // values (indices)
+ ncols, 0, sizeof(float) * 8, stream);
+ } else {
+ DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+ temp_indices, dst, // values (indices)
+ ncols * nrows, nrows, // num items, num segments
+ offset_iterator, offset_iterator + 1, stream);
+ }
+ } else {
+ if (nrows == 1) {
+ DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+ temp_indices, dst, // values (indices)
+ ncols, 0, sizeof(float) * 8, stream);
+ } else {
+ DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
+ dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1,
+ stream);
+ }
+ }
+
+ ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
+ void * d_temp_storage = temp_storage_alloc.get();
+
+ if (order == GGML_SORT_ORDER_ASC) {
+ if (nrows == 1) {
+ DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+ temp_indices, dst, // values (indices)
+ ncols, 0, sizeof(float) * 8, stream);
+ } else {
+ DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
+ ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream);
+ }
+ } else {
+ if (nrows == 1) {
+ DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
+ temp_indices, dst, // values (indices)
+ ncols, 0, sizeof(float) * 8, stream);
+ } else {
+ DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
+ temp_indices, dst, ncols * nrows, nrows, offset_iterator,
+ offset_iterator + 1, stream);
+ }
+ }
+}
+#endif // GGML_CUDA_USE_CUB
+
+// Bitonic sort implementation
+template<typename T>
+static inline __device__ void ggml_cuda_swap(T & a, T & b) {
+ T tmp = a;
+ a = b;
+ b = tmp;
+}
+
+template<ggml_sort_order order>
+static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
+ // bitonic sort
+ int col = threadIdx.x;
+ int row = blockIdx.x;
+
+ if (col >= ncols_pad) {
+ return;
+ }
+
+ const float * x_row = x + row * ncols;
+ extern __shared__ int dst_row[];
+
+ // initialize indices
+ dst_row[col] = col;
+
+ __syncthreads();
+
+ for (int k = 2; k <= ncols_pad; k *= 2) {
+ for (int j = k / 2; j > 0; j /= 2) {
+ int ixj = col ^ j;
+ if (ixj > col) {
+ if ((col & k) == 0) {
+ if (dst_row[col] >= ncols ||
+ (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+ ) {
+ ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+ }
+ } else {
+ if (dst_row[ixj] >= ncols ||
+ (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+ ) {
+ ggml_cuda_swap(dst_row[col], dst_row[ixj]);
+ }
+ }
+ }
+ __syncthreads();
+ }
+ }
+
+ // copy the result to dst without the padding
+ if (col < ncols) {
+ dst[row * ncols + col] = dst_row[col];
+ }
+}
+
+static int next_power_of_2(int x) {
+ int n = 1;
+ while (n < x) {
+ n *= 2;
+ }
+ return n;
+}
+
+void argsort_f32_i32_cuda_bitonic(const float * x,
+ int * dst,
+ const int ncols,
+ const int nrows,
+ ggml_sort_order order,
+ cudaStream_t stream) {
+ // bitonic sort requires ncols to be power of 2
+ const int ncols_pad = next_power_of_2(ncols);
+
+ const dim3 block_dims(ncols_pad, 1, 1);
+ const dim3 block_nums(nrows, 1, 1);
+ const size_t shared_mem = ncols_pad * sizeof(int);
+
+ // FIXME: this limit could be raised by ~2-4x on Ampere or newer
+ GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
+
+ if (order == GGML_SORT_ORDER_ASC) {
+ k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
+ <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+ } else if (order == GGML_SORT_ORDER_DESC) {
+ k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
+ <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+ } else {
+ GGML_ABORT("fatal error");
+ }
+}
+
+void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_I32);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ const int64_t ncols = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+
+#ifdef GGML_CUDA_USE_CUB
+ const int ncols_pad = next_power_of_2(ncols);
+ const size_t shared_mem = ncols_pad * sizeof(int);
+ const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+ if (shared_mem > max_shared_mem || ncols > 1024) {
+ ggml_cuda_pool & pool = ctx.pool();
+ argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
+ } else {
+ argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+ }
+#else
+ argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+#endif
+}
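
argsort above has two paths: a shared-memory bitonic sort for rows that fit (padded to the next power of two, with the >= ncols checks pushing padding indices past the real data), and a CUB radix / segmented sort fallback for larger rows. Both compute a row-wise argsort; a CPU sketch of that contract, with hypothetical names (note the GPU bitonic path is not a stable sort, so ties may be ordered differently):

#include <algorithm>
#include <cstdint>
#include <numeric>

// Illustrative CPU reference for the row-wise argsort semantics, not part of the diff.
static void argsort_rows_ref(const float * x, int * dst, int ncols, int nrows, bool ascending) {
    for (int r = 0; r < nrows; ++r) {
        const float * row = x + (int64_t) r * ncols;
        int         * out = dst + (int64_t) r * ncols;
        std::iota(out, out + ncols, 0);                 // indices 0..ncols-1
        std::sort(out, out + ncols, [&](int a, int b) { // order by the row's values
            return ascending ? row[a] < row[b] : row[a] > row[b];
        });
    }
}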
diff --git a/llama.cpp/ggml/src/ggml-cuda/argsort.cuh b/llama.cpp/ggml/src/ggml-cuda/argsort.cuh
new file mode 100644
index 0000000..22b7306
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/argsort.cuh
@@ -0,0 +1,19 @@
+#include "common.cuh"
+
+void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+#ifdef GGML_CUDA_USE_CUB
+void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
+ const float * x,
+ int * dst,
+ const int ncols,
+ const int nrows,
+ ggml_sort_order order,
+ cudaStream_t stream);
+#endif // GGML_CUDA_USE_CUB
+void argsort_f32_i32_cuda_bitonic(const float * x,
+ int * dst,
+ const int ncols,
+ const int nrows,
+ ggml_sort_order order,
+ cudaStream_t stream);
diff --git a/llama.cpp/ggml/src/ggml-cuda/binbcast.cu b/llama.cpp/ggml/src/ggml-cuda/binbcast.cu
new file mode 100644
index 0000000..7339fe0
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/binbcast.cu
@@ -0,0 +1,504 @@
+#include "binbcast.cuh"
+#include <cstdint>
+#include <utility>
+
+static __device__ __forceinline__ float op_repeat(const float a, const float b) {
+ return b;
+ GGML_UNUSED(a);
+}
+
+static __device__ __forceinline__ float op_add(const float a, const float b) {
+ return a + b;
+}
+
+static __device__ __forceinline__ float op_sub(const float a, const float b) {
+ return a - b;
+}
+
+static __device__ __forceinline__ float op_mul(const float a, const float b) {
+ return a * b;
+}
+
+static __device__ __forceinline__ float op_div(const float a, const float b) {
+ return a / b;
+}
+
+template <float (*bin_op)(const float, const float),
+ typename src0_t,
+ typename src1_t,
+ typename dst_t,
+ typename... src1_ptrs>
+static __global__ void k_bin_bcast(const src0_t * src0,
+ const src1_t * src1,
+ dst_t * dst,
+ const int ne0,
+ const int ne1,
+ const int ne2,
+ const uint3 ne3,
+ const uint3 ne10,
+ const uint3 ne11,
+ const uint3 ne12,
+ const uint3 ne13,
+ /*const int s0,*/
+ const int s1,
+ const int s2,
+ const int s3,
+ const int s00,
+ const int s01,
+ const int s02,
+ const int s03,
+ const int s10,
+ const int s11,
+ const int s12,
+ const int s13,
+ src1_ptrs... src1s) {
+ const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
+ const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
+ const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
+ const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
+
+ if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) {
+ return;
+ }
+
+ const uint32_t i11 = fastmodulo(i1, ne11);
+ const uint32_t i12 = fastmodulo(i2, ne12);
+ const uint32_t i13 = fastmodulo(i3, ne13);
+
+ const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
+ const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+ const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
+
+ const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
+ dst_t * dst_row = dst + i_dst;
+
+ for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
+ const uint32_t i10 = fastmodulo(i0, ne10);
+
+ float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
+ if constexpr (sizeof...(src1_ptrs) > 0) {
+ result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
+ } else {
+ result = bin_op(result, (float)src1[i_src1 + i10*s10]);
+ }
+
+ dst_row[i0] = (dst_t) result;
+ }
+}
+
+template <float (*bin_op)(const float, const float),
+ typename src0_t,
+ typename src1_t,
+ typename dst_t,
+ typename... src1_ptrs>
+static __global__ void k_bin_bcast_unravel(const src0_t * src0,
+ const src1_t * src1,
+ dst_t * dst,
+ const uint3 ne0,
+ const uint3 ne1,
+ const uint3 ne2,
+ const uint32_t ne3,
+ const uint3 prod_012,
+ const uint3 prod_01,
+ const uint3 ne10,
+ const uint3 ne11,
+ const uint3 ne12,
+ const uint3 ne13,
+ /*const int s0,*/
+ const int s1,
+ const int s2,
+ const int s3,
+ const int s00,
+ const int s01,
+ const int s02,
+ const int s03,
+ const int s10,
+ const int s11,
+ const int s12,
+ const int s13,
+ src1_ptrs... src1s) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ const uint32_t i3 = fastdiv(i, prod_012);
+ const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
+ const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
+ const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;
+
+ if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
+ return;
+ }
+
+ const int i11 = fastmodulo(i1, ne11);
+ const int i12 = fastmodulo(i2, ne12);
+ const int i13 = fastmodulo(i3, ne13);
+
+ const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
+ const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+ const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
+
+ const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
+ dst_t * dst_row = dst + i_dst;
+
+ const int i10 = fastmodulo(i0, ne10);
+
+ float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
+ if constexpr (sizeof...(src1_ptrs) > 0) {
+ result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
+ } else {
+ result = bin_op(result, (float)src1[i_src1 + i10*s10]);
+ }
+
+ dst_row[i0] = (dst_t) result;
+}
+
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, size_t... I>
+static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
+ cudaStream_t stream, std::index_sequence<I...>) {
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ int nr0 = ne10 / ne0;
+ int nr1 = ne11 / ne1;
+ int nr2 = ne12 / ne2;
+ int nr3 = ne13 / ne3;
+
+ int nr[4] = { nr0, nr1, nr2, nr3 };
+
+ int64_t cne[] = { ne0, ne1, ne2, ne3 };
+ int64_t cne0[] = { ne00, ne01, ne02, ne03 };
+ int64_t cne1[] = { ne10, ne11, ne12, ne13 };
+
+ size_t cnb[] = { nb0, nb1, nb2, nb3 };
+ size_t cnb0[] = { nb00, nb01, nb02, nb03 };
+ size_t cnb1[] = { nb10, nb11, nb12, nb13 };
+
+ auto collapse = [](int64_t cne[]) {
+ cne[0] *= cne[1];
+ cne[1] = cne[2];
+ cne[2] = cne[3];
+ cne[3] = 1;
+ };
+
+ auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
+ cnb[1] *= cne[1];
+ cnb[2] *= cne[2];
+ cnb[3] *= cne[3];
+ };
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
+ for (int i = 0; i < 4; i++) {
+ if (nr[i] != 1) {
+ break;
+ }
+ if (i > 0) {
+ collapse_nb(cnb, cne);
+ collapse_nb(cnb0, cne0);
+ collapse_nb(cnb1, cne1);
+ collapse(cne);
+ collapse(cne0);
+ collapse(cne1);
+ }
+ }
+ }
+
+ {
+ int64_t ne0 = cne[0];
+ int64_t ne1 = cne[1];
+ int64_t ne2 = cne[2];
+ int64_t ne3 = cne[3];
+
+ //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
+ //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
+ //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
+ //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
+
+ size_t nb0 = cnb[0];
+ size_t nb1 = cnb[1];
+ size_t nb2 = cnb[2];
+ size_t nb3 = cnb[3];
+
+ size_t nb00 = cnb0[0];
+ size_t nb01 = cnb0[1];
+ size_t nb02 = cnb0[2];
+ size_t nb03 = cnb0[3];
+
+ size_t nb10 = cnb1[0];
+ size_t nb11 = cnb1[1];
+ size_t nb12 = cnb1[2];
+ size_t nb13 = cnb1[3];
+
+ //size_t s0 = nb0 / sizeof(dst_t);
+ size_t s1 = nb1 / sizeof(dst_t);
+ size_t s2 = nb2 / sizeof(dst_t);
+ size_t s3 = nb3 / sizeof(dst_t);
+
+ size_t s10 = nb10 / sizeof(src1_t);
+ size_t s11 = nb11 / sizeof(src1_t);
+ size_t s12 = nb12 / sizeof(src1_t);
+ size_t s13 = nb13 / sizeof(src1_t);
+
+ size_t s00 = nb00 / sizeof(src0_t);
+ size_t s01 = nb01 / sizeof(src0_t);
+ size_t s02 = nb02 / sizeof(src0_t);
+ size_t s03 = nb03 / sizeof(src0_t);
+
+ GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+ GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+ GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+ GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+ GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+ GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
+ const int block_size = 128;
+
+ int64_t hne0 = std::max(ne0 / 2LL, 1LL);
+
+ dim3 block_dims;
+ block_dims.x = std::min<unsigned int>(hne0, block_size);
+ block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
+ block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
+
+ dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
+ (ne2 * ne3 + block_dims.z - 1) / block_dims.z);
+
+ const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);
+ const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);
+ const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
+ const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
+
+ if (block_nums.z > 65535 || block_nums.y > 65535) {
+ int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
+ const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
+ const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
+ const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
+ const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
+ const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
+
+ if constexpr (sizeof...(I) > 0) {
+ k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
+ ne12, ne13,
+ /*s0,*/ s1, s2, s3,
+ s00, s01, s02, s03,
+ s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
+ } else {
+ k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
+ <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
+ ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
+ /*s0,*/ s1, s2, s3,
+ s00, s01, s02, s03,
+ s10, s11, s12, s13);
+ }
+ } else {
+ const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
+ if constexpr (sizeof...(I) > 0) {
+ k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
+ /*s0,*/ s1, s2, s3,
+ s00 ,s01, s02, s03,
+ s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
+ } else {
+ k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
+ /*s0,*/ s1, s2, s3,
+ s00, s01, s02, s03,
+ s10, s11, s12, s13);
+ }
+ }
+ }
+}
+
+template <typename T>
+static __global__ void k_repeat_back(
+ const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+ const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
+
+ const int64_t tid0 = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
+ const int64_t tid1 = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
+ const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
+ const int64_t tid2 = tid23 % ne2;
+ const int64_t tid3 = tid23 / ne2;
+
+ if (tid0 >= ne0) {
+ return;
+ }
+
+ T sum = 0;
+ for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
+ for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+ for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+ for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+ sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
+ }
+ }
+ }
+ }
+ dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+}
+
+template <float (*bin_op)(const float, const float), int n_fuse = 1>
+struct bin_bcast_cuda {
+ template<typename src0_t, typename src1_t, typename dst_t>
+ void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
+ const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
+ cudaStream_t stream) {
+ launch_bin_bcast_pack<bin_op, src0_t, src1_t, dst_t>(
+ src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence<n_fuse>{});
+ }
+};
+
+template <typename T>
+static void repeat_back_cuda(
+ const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+ const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
+
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
+ k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
+ (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
+}
+
+template<class op>
+static void ggml_cuda_op_bin_bcast(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
+
+ GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
+
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+ op()(src0, src1, dst, (const half *) src0_dd, (const half *)src1_dd, (half *) dst_dd, stream);
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+ op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+ op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
+ } else {
+ fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+ ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+ GGML_ABORT("fatal error");
+ }
+}
+
+void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat, 0>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
+}
+
+void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
+void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
+void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
+}
+
+template <float (*op)(const float, const float), int n_fuse>
+static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ cudaStream_t stream = ctx.stream();
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,
+ (const float *) src0->data, (const float *) src1->data, (float *) dst->data,
+ stream, std::make_index_sequence<n_fuse>{});
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+ launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,
+ (const half *) src0->data, (const half *) src1->data, (half *) dst->data,
+ stream, std::make_index_sequence<n_fuse>{});
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+ launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,
+ (const half *) src0->data, (const float *) src1->data, (half *) dst->data,
+ stream, std::make_index_sequence<n_fuse>{});
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+ launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,
+ (const half *) src0->data, (const float *) src1->data, (float *) dst->data,
+ stream, std::make_index_sequence<n_fuse>{});
+ } else {
+ fprintf(stderr,
+ "%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n",
+ __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+ GGML_ABORT("fatal error");
+ }
+}
+
+
+void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
+ GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);
+
+ switch (n_fuse) {
+ case 2:
+ ggml_cuda_op_fused_binbcast_impl<op_add, 2>(ctx, dst);
+ break;
+ case 3:
+ ggml_cuda_op_fused_binbcast_impl<op_add, 3>(ctx, dst);
+ break;
+ case 4:
+ ggml_cuda_op_fused_binbcast_impl<op_add, 4>(ctx, dst);
+ break;
+ case 5:
+ ggml_cuda_op_fused_binbcast_impl<op_add, 5>(ctx, dst);
+ break;
+ case 6:
+ ggml_cuda_op_fused_binbcast_impl<op_add, 6>(ctx, dst);
+ break;
+ case 7:
+ ggml_cuda_op_fused_binbcast_impl<op_add, 7>(ctx, dst);
+ break;
+ case 8:
+ ggml_cuda_op_fused_binbcast_impl<op_add, 8>(ctx, dst);
+ break;
+ default:
+ GGML_ASSERT(false && "Unsupported n_fuse value");
+ }
+}
+
+void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(src0->type == dst->type);
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_can_repeat(dst, src0));
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_TENSOR_UNARY_OP_LOCALS;
+
+ GGML_ASSERT(ne2*ne3 <= (1 << 15));
+
+ const size_t ts = ggml_type_size(src0->type);
+ const size_t s00 = nb00 / ts;
+ const size_t s01 = nb01 / ts;
+ const size_t s02 = nb02 / ts;
+ const size_t s03 = nb03 / ts;
+
+ switch (dst->type) {
+ case GGML_TYPE_F32: {
+ const float * src0_d = (const float *) src0->data;
+ float * dst_d = (float *) dst->data;
+ repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
+ } break;
+ default: {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh b/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh
new file mode 100644
index 0000000..62bc950
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/binbcast.cuh
@@ -0,0 +1,11 @@
+#include "common.cuh"
+
+void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
diff --git a/llama.cpp/ggml/src/ggml-cuda/clamp.cu b/llama.cpp/ggml/src/ggml-cuda/clamp.cu
new file mode 100644
index 0000000..fe415e7
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/clamp.cu
@@ -0,0 +1,45 @@
+#include "clamp.cuh"
+
+static __device__ __forceinline__ float op_clamp(float x, float min, float max) {
+ return fminf(fmaxf(x, min), max);
+}
+
+template <class T>
+static __global__ void op_clamp_kernel(const T * x, T * dst, const T min, const T max, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ dst[i] = (T)op_clamp((float)x[i], (float)min, (float)max);
+}
+
+template <class T>
+static void clamp_cuda(const T * x, T * dst, const T min, const T max, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
+ op_clamp_kernel<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
+}
+
+
+void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const void * src0_d = src0->data;
+ void * dst_d = dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+ GGML_ASSERT(src0->type == dst->type);
+
+ float min;
+ float max;
+ memcpy(&min, dst->op_params, sizeof(float));
+ memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
+
+ if (src0->type == GGML_TYPE_F16) {
+ clamp_cuda((const half *)src0_d, (half *)dst_d, (half)min, (half)max, ggml_nelements(src0), stream);
+ } else {
+ clamp_cuda((const float *)src0_d, (float *)dst_d, (float)min, (float)max, ggml_nelements(src0), stream);
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/clamp.cuh b/llama.cpp/ggml/src/ggml-cuda/clamp.cuh
new file mode 100644
index 0000000..7f9559d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/clamp.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_CLAMP_BLOCK_SIZE 256
+
+void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/common.cuh b/llama.cpp/ggml/src/ggml-cuda/common.cuh
new file mode 100644
index 0000000..a3256d5
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/common.cuh
@@ -0,0 +1,1440 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-impl.h"
+#include "ggml-cuda.h"
+
+#include <cstdint>
+#include <memory>
+
+#if defined(GGML_USE_HIP)
+#define GGML_COMMON_DECL_HIP
+#define GGML_COMMON_IMPL_HIP
+#else
+#define GGML_COMMON_DECL_CUDA
+#define GGML_COMMON_IMPL_CUDA
+#if defined(GGML_USE_MUSA)
+#define GGML_COMMON_DECL_MUSA
+#define GGML_COMMON_IMPL_MUSA
+#endif
+#endif
+#include "ggml-common.h"
+
+#include <array>
+#include <algorithm>
+#include <cassert>
+#include <cfloat>
+#include <cstdio>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#if defined(GGML_USE_HIP)
+#include "vendors/hip.h"
+#elif defined(GGML_USE_MUSA)
+#include "vendors/musa.h"
+#else
+#include "vendors/cuda.h"
+#endif // defined(GGML_USE_HIP)
+
+#define STRINGIZE_IMPL(...) #__VA_ARGS__
+#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
+
+#define WARP_SIZE 32
+#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
+#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
+
+#define GGML_CUDA_CC_PASCAL 600
+#define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define GGML_CUDA_CC_VOLTA 700
+#define GGML_CUDA_CC_TURING 750
+#define GGML_CUDA_CC_AMPERE 800
+#define GGML_CUDA_CC_ADA_LOVELACE 890
+// While Blackwell spans CC 1000, 1100 & 1200, we only integrate the Tensor Core instructions available to the 1200 family, see
+// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
+#define GGML_CUDA_CC_BLACKWELL 1200
+#define GGML_CUDA_CC_DGX_SPARK 1210
+#define GGML_CUDA_CC_RUBIN 1300
+#define GGML_CUDA_CC_OFFSET_AMD 0x1000000
+#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
+#define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
+
+// AMD
+// GCN/CDNA, wave size is 64
+#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
+#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
+#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
+#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
+#define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renaming
+#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300
+
+// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
+#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
+#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
+#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
+#define GGML_CUDA_CC_RDNA3_5 (GGML_CUDA_CC_OFFSET_AMD + 0x1150) // AI 370, AI Max 395 laptops.
+#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
+
+#define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD)
+#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
+#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
+#define GGML_CUDA_CC_IS_RDNA3_0(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA3_5)
+#define GGML_CUDA_CC_IS_RDNA3_5(cc) (cc >= GGML_CUDA_CC_RDNA3_5 && cc < GGML_CUDA_CC_RDNA4)
+#define GGML_CUDA_CC_IS_RDNA3(cc) (GGML_CUDA_CC_IS_RDNA3_0(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc))
+#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
+#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
+#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2)
+#define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3)
+#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
+
+// Moore Threads
+#define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons
+
+#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
+#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
+#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000
+
+#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
+#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
+#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1)
+#define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1)
+
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+# define GGML_CUDA_USE_CUB
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+
+#ifdef __CUDA_ARCH_LIST__
+constexpr bool ggml_cuda_has_arch_impl(int) {
+ return false;
+}
+
+template<class ... Archs>
+constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
+ return arch == first || ggml_cuda_has_arch_impl(arch, rest...);
+}
+
+constexpr bool ggml_cuda_has_arch(const int arch) {
+ return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
+}
+
+constexpr int ggml_cuda_highest_compiled_arch_impl(const int /*arch*/, const int cur) {
+ if (cur == 0) {
+ return -1;
+ }
+ return cur;
+}
+
+template<class ... Archs>
+constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) {
+ if (first <= arch && first > cur) {
+ return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...);
+ } else {
+ return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...);
+ }
+}
+
+constexpr int ggml_cuda_highest_compiled_arch(const int arch) {
+ return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__);
+}
+#else
+static int ggml_cuda_highest_compiled_arch(const int arch) {
+ return arch;
+}
+#endif // __CUDA_ARCH_LIST__
+
+// ---------------------------------------------------------------------------------------------------------
+
+#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
+
+#define GGML_CUDA_MAX_STREAMS 8
+
+[[noreturn]]
+void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);
+
+#define CUDA_CHECK_GEN(err, success, error_fn) \
+ do { \
+ auto err_ = (err); \
+ if (err_ != (success)) { \
+ ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_)); \
+ } \
+ } while (0)
+
+#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
+
+#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
+ static const char * cublas_get_error_str(const cublasStatus_t err) {
+ return cublasGetStatusString(err);
+ }
+#else
+ static const char * cublas_get_error_str(const cublasStatus_t err) {
+ switch (err) {
+ case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
+ case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
+ case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
+ case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
+ case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
+ case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
+ case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
+ case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
+ case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
+ default: return "unknown error";
+ }
+ }
+#endif // CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
+
+#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
+
+#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
+static const char * cu_get_error_str(CUresult err) {
+ const char * err_str;
+ cuGetErrorString(err, &err_str);
+ return err_str;
+}
+#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
+#endif
+
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
+ do { \
+ static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \
+ const int id = ggml_cuda_get_device(); \
+ if (!shared_memory_limit_raised[id]) { \
+ CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
+ shared_memory_limit_raised[id] = true; \
+ } \
+ } while (0)
+#else
+# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
+ do { \
+ GGML_UNUSED(nbytes); \
+ } while (0)
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+
+#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
+#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
+#else
+#define GGML_CUDA_ASSUME(x)
+#endif // CUDART_VERSION >= 11010
+
+#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
+#define GGML_USE_VMM
+#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
+
+#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#define FP16_AVAILABLE
+#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+
+#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
+#define FAST_FP16_AVAILABLE
+#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
+
+#if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
+#define AMD_MFMA_AVAILABLE
+#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
+
+#if defined(GGML_USE_HIP) && (defined(RDNA4) || defined(RDNA3))
+#define AMD_WMMA_AVAILABLE
+#endif // defined(GGML_USE_HIP) && (defined(RDNA4) || defined(RDNA3))
+
+// The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#define VOLTA_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
+#define TURING_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
+
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define AMPERE_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN
+# define BLACKWELL_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN
+
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#define CP_ASYNC_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
+#define FLASH_ATTN_AVAILABLE
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
+
+#if defined(TURING_MMA_AVAILABLE)
+#define LDMATRIX_TRANS_AVAILABLE
+#endif // defined(TURING_MMA_AVAILABLE)
+
+static bool fp16_available(const int cc) {
+ return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
+ (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
+}
+
+static bool fast_fp16_available(const int cc) {
+ return GGML_CUDA_CC_IS_AMD(cc) ||
+ (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) ||
+ (GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc));
+}
+
+// To be used for feature selection of external libraries, e.g. cuBLAS.
+static bool fast_fp16_hardware_available(const int cc) {
+ return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc) ||
+ (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
+}
+
+// To be used for feature selection of external libraries, e.g. cuBLAS.
+static bool fp16_mma_hardware_available(const int cc) {
+ return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
+ GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
+ (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
+}
+
+static bool bf16_mma_hardware_available(const int cc) {
+ return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) ||
+ GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 ||
+ (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
+}
+
+static bool fp32_mma_hardware_available(const int cc) {
+ return GGML_CUDA_CC_IS_CDNA(cc);
+}
+
+static bool amd_mfma_available(const int cc) {
+#if !defined(GGML_HIP_NO_MMQ_MFMA)
+ return GGML_CUDA_CC_IS_CDNA(cc);
+#else
+ return false;
+#endif //!defined(GGML_HIP_NO_MMQ_MFMA)
+}
+
+static bool amd_wmma_available(const int cc) {
+ return (GGML_CUDA_CC_IS_RDNA4(cc) || GGML_CUDA_CC_IS_RDNA3(cc));
+}
+
+static bool volta_mma_available(const int cc) {
+ return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
+}
+
+static bool turing_mma_available(const int cc) {
+ return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
+}
+
+static bool ampere_mma_available(const int cc) {
+ return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
+static bool cp_async_available(const int cc) {
+ return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
+}
+
+static bool blackwell_mma_available(const int cc) {
+ return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL &&
+ ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN;
+}
+
+static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
+#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
+ return 64;
+#else
+ return 32;
+#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
+}
+
+// Maximum number of bytes that can be copied in a single instruction.
+static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() {
+#ifdef GGML_USE_HIP
+ return 16;
+#else
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+ return 16;
+#else
+ return 8;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#endif // GGML_USE_HIP
+}
+
+
+[[noreturn]]
+static __device__ void no_device_code(
+ const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
+
+#if defined(GGML_USE_HIP)
+ printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
+ file_name, line, function_name, arch);
+ GGML_UNUSED(arch_list);
+#else
+ printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
+ file_name, line, function_name, arch, arch_list);
+#endif // defined(GGML_USE_HIP)
+ __trap();
+
+ GGML_UNUSED(no_device_code); // suppress unused function warning
+
+#if defined(GGML_USE_MUSA)
+ __builtin_unreachable();
+#endif // defined(GGML_USE_MUSA)
+}
+
+#ifdef __CUDA_ARCH__
+#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
+#else
+#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
+#endif // __CUDA_ARCH__
+
+// The compiler is not always able to unroll loops if they contain continue expressions.
+// In such cases loop unrolling can still be achieved via recursion:
+template <int n>
+struct ggml_cuda_unroll {
+ template <typename Func, typename... Args>
+ __device__ void operator()(const Func & f, Args... args) const {
+ f(n - 1, args...);
+ ggml_cuda_unroll<n - 1>{}(f, args...);
+ }
+};
+
+template <>
+struct ggml_cuda_unroll<1> {
+ template <typename Func, typename... Args>
+ __device__ void operator()(const Func & f, Args... args) const {
+ f(0, args...);
+ }
+};
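+
+// Illustrative usage sketch (not from the original sources; the function and buffer names are
+// hypothetical): the functor receives the index counting down from n-1 to 0, so a loop body that
+// would otherwise need `continue` can simply `return`, and the "loop" is still fully unrolled.
+static __device__ __forceinline__ float ggml_cuda_unroll_example(const float * vals) {
+    float acc = 0.0f;
+    ggml_cuda_unroll<4>{}([&](const int i) {
+        if (vals[i] < 0.0f) {
+            return; // plays the role of `continue` in the equivalent loop
+        }
+        acc += vals[i];
+    });
+    return acc;
+}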
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_sum(int x) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+ return __reduce_add_sync(0xffffffff, x);
+#else
+#pragma unroll
+ for (int offset = width/2; offset > 0; offset >>= 1) {
+ x += __shfl_xor_sync(0xffffffff, x, offset, width);
+ }
+ return x;
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+ for (int offset = width/2; offset > 0; offset >>= 1) {
+ x += __shfl_xor_sync(0xffffffff, x, offset, width);
+ }
+ return x;
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+ for (int offset = width/2; offset > 0; offset >>= 1) {
+ a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
+ a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
+ }
+ return a;
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#ifdef FP16_AVAILABLE
+#pragma unroll
+ for (int offset = width/2; offset > 0; offset >>= 1) {
+ a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
+ }
+ return a;
+
+#else
+ NO_DEVICE_CODE;
+ return a;
+#endif // FP16_AVAILABLE
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_all(int x) {
+ if (width == ggml_cuda_get_physical_warp_size()) {
+ return __all_sync(0xffffffff, x);
+ } else {
+#pragma unroll
+ for (int offset = width/2; offset > 0; offset >>= 1) {
+ x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
+ }
+ return x;
+ }
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_any(int x) {
+ if (width == ggml_cuda_get_physical_warp_size()) {
+ return __any_sync(0xffffffff, x);
+ } else {
+#pragma unroll
+ for (int offset = width/2; offset > 0; offset >>= 1) {
+ x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
+ }
+ return x;
+ }
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+ for (int offset = width/2; offset > 0; offset >>= 1) {
+ x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
+ }
+ return x;
+}
+
+template<typename T, int width = WARP_SIZE>
+static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
+ const int lane_id = threadIdx.x % width;
+#pragma unroll
+ for (int offset = 1; offset < width; offset <<= 1) {
+ const T t = __shfl_up_sync(0xffffffff, x, offset, width);
+ if (lane_id >= offset) {
+ x += t;
+ }
+ }
+ return x;
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
+ const int lane_id = threadIdx.x % width;
+#pragma unroll
+ for (int offset = 1; offset < width; offset <<= 1) {
+ const float t_x = __shfl_up_sync(0xffffffff, a.x, offset, width);
+ const float t_y = __shfl_up_sync(0xffffffff, a.y, offset, width);
+ if (lane_id >= offset) {
+ a.x += t_x;
+ a.y += t_y;
+ }
+ }
+ return a;
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
+#ifdef FP16_AVAILABLE
+ const int lane_id = threadIdx.x % width;
+#pragma unroll
+ for (int offset = 1; offset < width; offset <<= 1) {
+ const half2 t = __shfl_up_sync(0xffffffff, a, offset, width);
+ if (lane_id >= offset) {
+ a = __hadd2(a, t);
+ }
+ }
+ return a;
+
+#else
+ NO_DEVICE_CODE;
+ return a;
+#endif // FP16_AVAILABLE
+}
+
+enum class block_reduce_method {
+ MAX,
+ SUM,
+};
+
+template<block_reduce_method method_t, typename T>
+struct block_reduce_policy;
+
+template <typename T, typename... Ts>
+inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...);
+
+template<typename...>
+inline constexpr bool ggml_cuda_dependent_false_v = false;
+
+template <typename T> struct block_reduce_policy<block_reduce_method::SUM, T> {
+ static __device__ T reduce(T val) {
+ if constexpr(is_any<T, float, float2, half2, int>) {
+ return warp_reduce_sum(val);
+ } else {
+ static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
+ }
+ }
+
+ static __device__ T sentinel() {
+ if constexpr (std::is_same_v<T, float>) {
+ return 0.0f;
+ } else if constexpr (std::is_same_v<T, float2>) {
+ return make_float2(0.0f, 0.0f);
+ } else if constexpr (std::is_same_v<T, half2>) {
+ return make_half2(0.0f, 0.0f);
+ } else if constexpr (std::is_same_v<T, int>) {
+ return 0;
+ } else {
+ static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum");
+ }
+ }
+};
+
+template <typename T> struct block_reduce_policy<block_reduce_method::MAX, T> {
+ static __device__ T reduce(T val) {
+ if constexpr (is_any<T, float, half2>) {
+ return warp_reduce_max(val);
+ } else {
+ static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
+ }
+ }
+
+ static __device__ T sentinel() {
+ if constexpr (std::is_same_v<T, float>) {
+ return -INFINITY;
+ } else if constexpr (std::is_same_v<T, half2>) {
+ return make_half2(-INFINITY, -INFINITY);
+ } else {
+ static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max");
+ }
+ }
+};
+
+template <block_reduce_method reduce_method_t, const unsigned int block_size_template = 0, typename T>
+static __device__ T block_reduce(T val, T * shared_vals) {
+ val = block_reduce_policy<reduce_method_t, T>::reduce(val);
+ const unsigned int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+ if (block_size > WARP_SIZE) {
+ assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0);
+ const int warp_id = threadIdx.x / WARP_SIZE;
+ const int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ shared_vals[warp_id] = val;
+ }
+ __syncthreads();
+ val = block_reduce_policy<reduce_method_t, T>::sentinel();
+ if (lane_id < (static_cast<int>(block_size) / WARP_SIZE)) {
+ val = shared_vals[lane_id];
+ }
+ return block_reduce_policy<reduce_method_t, T>::reduce(val);
+ }
+
+ return val;
+}
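+
+// Illustrative usage sketch (not from the original sources; the kernel and buffer names are
+// hypothetical): a block-wide sum using one shared-memory slot per warp. Assumes the kernel is
+// launched with blockDim.x being a multiple of WARP_SIZE (and at most 1024 threads).
+static __global__ void example_block_sum(const float * __restrict__ x, float * __restrict__ out, const int n) {
+    __shared__ float shared_vals[32]; // one slot per warp
+    float partial = 0.0f;
+    for (int i = threadIdx.x; i < n; i += blockDim.x) {
+        partial += x[i];
+    }
+    const float total = block_reduce<block_reduce_method::SUM>(partial, shared_vals);
+    if (threadIdx.x == 0) {
+        out[blockIdx.x] = total;
+    }
+}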
+
+static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
+#ifdef FP16_AVAILABLE
+
+#if !defined(GGML_USE_HIP) && CUDART_VERSION < CUDART_HMAX
+ return __float2half(fmaxf(__half2float(a), __half2float(b)));
+#else
+ return __hmax(a, b);
+#endif // !defined(GGML_USE_HIP) && CUDART_VERSION < CUDART_HMAX
+
+#else
+ NO_DEVICE_CODE;
+ GGML_UNUSED(b);
+ return a;
+#endif // FP16_AVAILABLE
+}
+
+static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
+#if defined(GGML_USE_HIP)
+ return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));
+#elif CUDART_VERSION >= CUDART_HMAX
+ return __hmax2(a, b);
+#else
+ half2 ret;
+ reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a), __low2float(b)));
+ reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
+ return ret;
+#endif
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)
+#pragma unroll
+ for (int offset = width/2; offset > 0; offset >>= 1) {
+ x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
+ }
+ return x;
+#else
+ GGML_UNUSED(x);
+ NO_DEVICE_CODE;
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)
+}
+
+#if (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || \
+ (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK)
+static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
+ const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
+ const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
+ return mask_low | mask_high;
+}
+#endif // (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK)
+
+static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
+#if defined(GGML_USE_HIP)
+#if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
+ c = __builtin_amdgcn_sdot4(a, b, c, false);
+#elif defined(RDNA3) || defined(RDNA4)
+ c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
+#elif defined(RDNA1) || defined(__gfx900__)
+ int tmp1;
+ int tmp2;
+ asm("\n \
+ v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
+ v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
+ v_add3_u32 %0, %1, %2, %0 \n \
+ v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
+ v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
+ v_add3_u32 %0, %1, %2, %0 \n \
+ "
+ : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
+ : "v"(a), "v"(b)
+ );
+#else
+ const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+ const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+ c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
+#endif
+ return c;
+
+#else // defined(GGML_USE_HIP)
+
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
+ return __dp4a(a, b, c);
+#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
+ const int8_t * a8 = (const int8_t *) &a;
+ const int8_t * b8 = (const int8_t *) &b;
+ return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
+
+#endif // defined(GGML_USE_HIP)
+}
+
+static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) {
+ acc += v*u;
+}
+
+static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) {
+ acc += v.x*u.x;
+ acc += v.y*u.y;
+}
+
+#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
+#define V_DOT2_F32_F16_AVAILABLE
+#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
+
+static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
+#ifdef V_DOT2_F32_F16_AVAILABLE
+ asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
+#else
+#ifdef FAST_FP16_AVAILABLE
+ const float2 tmp = __half22float2(v*u);
+ acc += tmp.x + tmp.y;
+#else
+ const float2 tmpv = __half22float2(v);
+ const float2 tmpu = __half22float2(u);
+ acc += tmpv.x * tmpu.x;
+ acc += tmpv.y * tmpu.y;
+#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
+}
+
+static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
+#ifdef FAST_FP16_AVAILABLE
+ acc += v*u;
+#else
+ const float2 tmpv = __half22float2(v);
+ const float2 tmpu = __half22float2(u);
+ float2 tmpacc = __half22float2(acc);
+ tmpacc.x += tmpv.x * tmpu.x;
+ tmpacc.y += tmpv.y * tmpu.y;
+ acc = make_half2(tmpacc.x, tmpacc.y);
+#endif // FAST_FP16_AVAILABLE
+}
+
+// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
+// Important: do not use this function if dst and src both point at registers.
+// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types.
+// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions.
+// If dst and src point at different address spaces then they are guaranteed to not be aliased.
+template <int nbytes, int alignment = 0>
+static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
+ static_assert(
+ nbytes <= ggml_cuda_get_max_cpy_bytes() || alignment == 0,
+ "You are misusing the alignment parameter for ggml_cuda_memcpy_1. "
+        "The parameter is intended only as a workaround for when one of the pointers is not properly aligned. "
+        "If you use it to do more bytes per copy than ggml_cuda_get_max_cpy_bytes() the reads and writes may not be coalesced. "
+ "Call ggml_cuda_memcpy_1 in a loop instead.");
+ if constexpr (alignment != 0) {
+ static_assert(nbytes % alignment == 0, "bad alignment");
+ }
+ constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment;
+
+#pragma unroll
+ for (int i = 0; i < nbytes/nb_per_cpy; ++i) {
+ if constexpr (nb_per_cpy == 1) {
+ ((char *) dst)[i] = ((const char *) src)[i];
+ } else if constexpr (nb_per_cpy == 2) {
+ ((short *) dst)[i] = ((const short *) src)[i];
+ } else if constexpr (nb_per_cpy == 4) {
+ ((int *) dst)[i] = ((const int *) src)[i];
+ } else if constexpr (nb_per_cpy == 8) {
+ ((int2 *) dst)[i] = ((const int2 *) src)[i];
+ } else if constexpr (nb_per_cpy == 16) {
+ ((int4 *) dst)[i] = ((const int4 *) src)[i];
+ } else {
+ static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
+ }
+ }
+}
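+
+// Illustrative usage sketch (not from the original sources; the function name is hypothetical):
+// loading a 16-byte tile of half2 values from global memory into registers with a single vectorized
+// transaction, assuming the source pointer is 16-byte aligned.
+static __device__ __forceinline__ void example_load_tile(half2 (&reg)[4], const half2 * __restrict__ src) {
+    ggml_cuda_memcpy_1<sizeof(reg)>(reg, src); // 16 bytes -> a single int4 load/store
+}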
+
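+// Note: E8M0 stores only a biased power-of-two exponent, i.e. the decoded value is 2^(x - 127).
+// x == 0 therefore decodes to 2^-127, which is below the smallest normal float and is encoded by
+// the subnormal bit pattern 0x00400000 in the fallback path below.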
+static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
+#if CUDART_VERSION >= 12080
+ const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
+ return (float) e;
+#else
+ uint32_t bits;
+ if (x == 0) {
+ bits = 0x00400000;
+ } else {
+ bits = (uint32_t) x << 23;
+ }
+
+ float result;
+ memcpy(&result, &bits, sizeof(float));
+ return result;
+#endif // CUDART_VERSION >= 12080
+}
+
+__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
+ const uint8_t sign_bit = (x < 0.0f) << 3;
+ float ax = fabsf(x) * e;
+
+ // Positive LUT
+ static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
+
+ int best_i = 0;
+ float best_err = fabsf(ax - pos_lut[0]);
+
+#pragma unroll
+ for (int i = 1; i < 8; ++i) {
+ const float err = fabsf(ax - pos_lut[i]);
+ if (err < best_err) {
+ best_err = err;
+ best_i = i;
+ }
+ }
+
+ return static_cast<uint8_t>(best_i | sign_bit);
+}
+
+// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
+// Precompute mp (m' in the paper) and L such that division
+// can be computed using a multiply (high 32b of 64b result)
+// and a shift:
+//
+// n/d = (mulhi(n, mp) + n) >> L;
+static const uint3 init_fastdiv_values(uint64_t d_64) {
+ GGML_ASSERT(d_64 != 0);
+ GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());
+
+ uint32_t d = (uint32_t)d_64;
+
+ // compute L = ceil(log2(d));
+ uint32_t L = 0;
+ while (L < 32 && (uint32_t{ 1 } << L) < d) {
+ L++;
+ }
+
+ uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
+ // pack divisor as well to reduce error surface
+ return make_uint3(mp, L, d);
+}
+
+static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fastdiv_values) {
+ // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z>
+ // fastdiv_values.z is unused and optimized away by the compiler.
+ // Compute high 32 bits of n * mp
+ const uint32_t hi = __umulhi(n, fastdiv_values.x);
+ // add n, apply bit shift
+ return (hi + n) >> fastdiv_values.y;
+}
+
+static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fastdiv_values) {
+ // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
+ return n - fastdiv(n, fastdiv_values) * fastdiv_values.z;
+}
+
+// Calculate both division and modulo at once, returns <n/divisor, n%divisor>
+static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 fastdiv_values) {
+ // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
+ const uint32_t div_val = fastdiv(n, fastdiv_values);
+ const uint32_t mod_val = n - div_val * fastdiv_values.z;
+ return make_uint2(div_val, mod_val);
+}
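+
+// Illustrative usage sketch (not from the original sources; the kernel and parameter names are
+// hypothetical): the divisor is fixed per launch, so the host precomputes the <mp, L, d> triple once
+// with init_fastdiv_values(ncols) and every thread replaces an integer division by a multiply-high,
+// an add and a shift.
+static __global__ void example_scale_rows(float * x, const float * row_scales, const uint3 ncols_fd, const int n) {
+    const int i = blockIdx.x*blockDim.x + threadIdx.x;
+    if (i >= n) {
+        return;
+    }
+    const uint32_t row = fastdiv(i, ncols_fd); // i / ncols without an integer division instruction
+    x[i] *= row_scales[row];
+}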
+
+typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);
+
+static __device__ __forceinline__ float get_alibi_slope(
+ const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
+) {
+ if (max_bias <= 0.0f) {
+ return 1.0f;
+ }
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ return powf(base, exph);
+}
+
+template <ggml_type type>
+struct ggml_cuda_type_traits;
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_F16> {
+ static constexpr int qk = 1;
+ static constexpr int qr = 1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
+ static constexpr int qk = QK4_0;
+ static constexpr int qr = QR4_0;
+ static constexpr int qi = QI4_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
+ static constexpr int qk = QK4_1;
+ static constexpr int qr = QR4_1;
+ static constexpr int qi = QI4_1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
+ static constexpr int qk = QK5_0;
+ static constexpr int qr = QR5_0;
+ static constexpr int qi = QI5_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
+ static constexpr int qk = QK5_1;
+ static constexpr int qr = QR5_1;
+ static constexpr int qi = QI5_1;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
+ static constexpr int qk = QK8_0;
+ static constexpr int qr = QR8_0;
+ static constexpr int qi = QI8_0;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
+ static constexpr int qk = QK_MXFP4;
+ static constexpr int qr = QR_MXFP4;
+ static constexpr int qi = QI_MXFP4;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR2_K;
+ static constexpr int qi = QI2_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR3_K;
+ static constexpr int qi = QI3_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR4_K;
+ static constexpr int qi = QI4_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR5_K;
+ static constexpr int qi = QI5_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR6_K;
+ static constexpr int qi = QI6_K;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR2_XXS;
+ static constexpr int qi = QI2_XXS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR2_XS;
+ static constexpr int qi = QI2_XS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR2_S;
+ static constexpr int qi = QI2_S;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR3_XXS;
+ static constexpr int qi = QI3_XXS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR1_S;
+ static constexpr int qi = QI1_S;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR1_M;
+ static constexpr int qi = QI1_M;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
+ static constexpr int qk = QK4_NL;
+ static constexpr int qr = QR4_NL;
+ static constexpr int qi = QI4_NL;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR4_XS;
+ static constexpr int qi = QI4_XS;
+};
+
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
+ static constexpr int qk = QK_K;
+ static constexpr int qr = QR3_S;
+ static constexpr int qi = QI3_S;
+};
+
+//////////////////////
+
+struct ggml_cuda_device_info {
+ int device_count;
+
+ struct cuda_device_info {
+ int cc; // compute capability
+ int nsm; // number of streaming multiprocessors
+ size_t smpb; // max. shared memory per block
+ size_t smpbo; // max. shared memory per block (with opt-in)
+ bool integrated; // Device is integrated as opposed to discrete
+ bool vmm; // virtual memory support
+ size_t vmm_granularity; // granularity of virtual memory
+ size_t total_vram;
+ int warp_size; // Number of threads in a dispatch
+ bool supports_cooperative_launch; // whether cooperative launch is supported
+ };
+
+ cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
+
+ std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
+};
+
+const ggml_cuda_device_info & ggml_cuda_info();
+
+void ggml_cuda_set_device(int device);
+int ggml_cuda_get_device();
+
+struct ggml_cuda_pool {
+ virtual ~ggml_cuda_pool() = default;
+
+ virtual void * alloc(size_t size, size_t * actual_size) = 0;
+ virtual void free(void * ptr, size_t size) = 0;
+};
+
+template<typename T>
+struct ggml_cuda_pool_alloc {
+ ggml_cuda_pool * pool = nullptr;
+ T * ptr = nullptr;
+ size_t actual_size = 0;
+
+ ggml_cuda_pool_alloc() = default;
+
+ explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) {
+ }
+
+ ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : pool(&pool) {
+ alloc(size);
+ }
+
+ ~ggml_cuda_pool_alloc() {
+ if (ptr != nullptr) {
+ pool->free(ptr, actual_size);
+ }
+ }
+
+ // size is in number of elements
+ T * alloc(size_t size) {
+ GGML_ASSERT(pool != nullptr);
+ GGML_ASSERT(ptr == nullptr);
+ ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
+ return ptr;
+ }
+
+ T * alloc(ggml_cuda_pool & pool, size_t size) {
+ this->pool = &pool;
+ return alloc(size);
+ }
+
+ T * get() {
+ return ptr;
+ }
+
+ ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete;
+ ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete;
+ ggml_cuda_pool_alloc& operator=(const ggml_cuda_pool_alloc &) = delete;
+ ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete;
+};
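+
+// Illustrative usage sketch (not from the original sources; the function name is hypothetical): a
+// temporary device buffer whose lifetime is tied to the enclosing scope; the memory is handed back
+// to the pool by the destructor instead of being freed.
+static void example_with_temp_buffer(ggml_cuda_pool & pool, cudaStream_t stream, int64_t nelements) {
+    ggml_cuda_pool_alloc<float> tmp(pool, nelements); // allocates nelements * sizeof(float) bytes
+    float * tmp_d = tmp.get();
+    // ... enqueue kernels that read/write tmp_d on `stream` ...
+    GGML_UNUSED(tmp_d);
+    GGML_UNUSED(stream);
+} // tmp is returned to the pool here, when it goes out of scope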
+
+
+// backend interface
+
+struct ggml_tensor_extra_gpu {
+ void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
+ cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
+};
+
+
+#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)) || defined(GGML_MUSA_GRAPHS)
+#define USE_CUDA_GRAPH
+#endif
+
+struct ggml_cuda_graph_node_properties {
+ void * node_data;
+ ggml_op node_op;
+ enum ggml_type node_type;
+ int32_t flags;
+ int64_t ne[GGML_MAX_DIMS];
+ size_t nb[GGML_MAX_DIMS];
+ void * src_data[GGML_MAX_SRC];
+ int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
+};
+
+static_assert(std::is_trivial<ggml_cuda_graph_node_properties>::value, "ggml_cuda_graph_node_properties must be trivial");
+
+struct ggml_cuda_graph {
+#ifdef USE_CUDA_GRAPH
+ ~ggml_cuda_graph() {
+ if (instance != nullptr) {
+ CUDA_CHECK(cudaGraphExecDestroy(instance));
+ }
+ if (graph != nullptr) {
+ CUDA_CHECK(cudaGraphDestroy(graph));
+ }
+ }
+ cudaGraph_t graph = nullptr;
+ cudaGraphExec_t instance = nullptr;
+ size_t num_nodes = 0;
+ std::vector<cudaGraphNode_t> nodes;
+ bool disable_due_to_gpu_arch = false;
+ bool disable_due_to_too_many_updates = false;
+ int number_consecutive_updates = 0;
+ std::vector<ggml_cuda_graph_node_properties> props;
+
+ // these are extra tensors (inputs) that participate in the ggml graph but are not nodes
+    // their properties also have to match in order to be able to safely reuse a CUDA graph
+ // ref: https://github.com/ggml-org/llama.cpp/pull/18583
+ // ref: https://github.com/ggml-org/llama.cpp/pull/19165
+ std::vector<ggml_cuda_graph_node_properties> extra;
+
+ void record_update(bool use_graph, bool update_required) {
+ if (use_graph && update_required) {
+ number_consecutive_updates++;
+ } else {
+ number_consecutive_updates = 0;
+ }
+ if (number_consecutive_updates >= 4) {
+ GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+ disable_due_to_too_many_updates = true;
+ }
+ }
+
+ bool is_enabled() const {
+ static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
+ return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates);
+ }
+#endif
+};
+
+struct ggml_cuda_concurrent_event {
+ std::vector<cudaEvent_t> join_events;
+ cudaEvent_t fork_event = nullptr;
+
+ int n_streams = 0;
+ std::unordered_map<const ggml_tensor *, int> stream_mapping;
+
+ // Original order of nodes in this concurrent region (before interleaving)
+ // Used to restore grouping for fusion within streams
+ std::vector<const ggml_tensor *> original_order;
+
+ const ggml_tensor * join_node;
+
+ ggml_cuda_concurrent_event() = default;
+
+ ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete;
+ ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete;
+
+ explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
+ join_events.resize(n_streams);
+
+ for (size_t i = 0; i < join_events.size(); ++i) {
+ CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
+ }
+
+ CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
+ }
+
+ ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept
+ : join_events(std::move(other.join_events))
+ , fork_event(other.fork_event)
+ , n_streams(other.n_streams)
+ , stream_mapping(std::move(other.stream_mapping))
+ , original_order(std::move(other.original_order))
+ , join_node(other.join_node) {
+ other.fork_event = nullptr;
+ }
+
+ // 1. check if any branches write to overlapping memory ranges (except the join node)
+ // 2. check whether all srcs are either within the branch or outside the nodes covered by ggml_cuda_concurrent_event
+ // we assume all nodes have the same buffer
+ bool is_valid() const {
+ std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges;
+ write_ranges.resize(n_streams);
+
+ // get join_node's memory range to exclude from overlap checking.
+ // multiple nodes can use join_node's buffer; we synchronize on the join node.
+ const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node;
+ const int64_t join_start = (int64_t) join_t->data;
+ const int64_t join_end = join_start + ggml_nbytes(join_t);
+
+ for (const auto & [tensor, stream] : stream_mapping) {
+ const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
+ const int64_t t_start = (int64_t) t->data;
+ const int64_t t_end = t_start + ggml_nbytes(t);
+
+ // skip tensors that overlap with join_node's buffer.
+ if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
+ continue;
+ }
+
+ // concurrent streams begin from 1
+ write_ranges[stream - 1].emplace_back(t_start, t_end);
+ }
+
+ for (int i = 0; i < n_streams; ++i) {
+ // sorts first by start then by end of write range
+ std::sort(write_ranges[i].begin(), write_ranges[i].end());
+ }
+
+ bool writes_overlap = false;
+ bool dependent_srcs = false;
+ for (const auto & [tensor, stream] : stream_mapping) {
+ const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
+ const int64_t t_start = (int64_t) t->data;
+ const int64_t t_end = t_start + ggml_nbytes(t);
+
+ // skip tensors that overlap with join_node's buffer
+ if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
+ continue;
+ }
+
+ // check if this buffer's write data overlaps with another stream's
+ std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end);
+ for (int i = 0; i < n_streams; ++i) {
+ if (i == stream - 1) {
+ continue;
+ }
+ auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range);
+
+ if (it != write_ranges[i].end()) {
+ const std::pair<int64_t, int64_t> & other = *it;
+
+ // std::lower_bound returns the first element where other >= data_range (lexicographically).
+ // This guarantees other.first >= data_range.first.
+ // Therefore, overlap occurs iff other.first < data_range.second
+ // (i.e., the other range starts before this range ends).
+ if (other.first < data_range.second) {
+ GGML_LOG_DEBUG("Writes overlap for %s", tensor->name);
+ writes_overlap = true;
+ break;
+ }
+ }
+ }
+
+        // check that each src is either computed in the same branch or lies outside this concurrent region
+ for (int i = 0; i < GGML_MAX_SRC; ++i) {
+ if (!tensor->src[i]) {
+ continue;
+ }
+
+ auto it = stream_mapping.find(tensor->src[i]);
+
+ if (it == stream_mapping.end()) {
+ continue;
+ }
+
+ if (it->second != stream) {
+ dependent_srcs = true;
+ break;
+ }
+ }
+
+ if (dependent_srcs || writes_overlap) {
+ break;
+ }
+ }
+
+ return !writes_overlap && !dependent_srcs;
+ }
+
+ ~ggml_cuda_concurrent_event() {
+ if (fork_event != nullptr) {
+ CUDA_CHECK(cudaEventDestroy(fork_event));
+ }
+ for (cudaEvent_t e : join_events) {
+ if (e != nullptr) {
+ CUDA_CHECK(cudaEventDestroy(e));
+ }
+ }
+ }
+};
+
+struct ggml_cuda_stream_context {
+ std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
+
+ void reset() {
+ concurrent_events.clear();
+ }
+};
+
+struct ggml_backend_cuda_context {
+ int device;
+ std::string name;
+ cudaEvent_t copy_event = nullptr;
+
+ cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
+ cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
+
+ int curr_stream_no = 0;
+
+#ifdef USE_CUDA_GRAPH
+ // Map from first_node_ptr to cuda_graph - allows multiple graphs per context
+ // when the computation is split across CPU/GPU (e.g., with --n-cpu-moe)
+ std::unordered_map<const void *, std::unique_ptr<ggml_cuda_graph>> cuda_graphs;
+
+ ggml_cuda_graph * cuda_graph(const void * first_node_ptr) {
+ auto it = cuda_graphs.find(first_node_ptr);
+ if (it == cuda_graphs.end()) {
+ cuda_graphs[first_node_ptr] = std::make_unique<ggml_cuda_graph>();
+ return cuda_graphs[first_node_ptr].get();
+ }
+ return it->second.get();
+ }
+
+ // Check if any CUDA graph is enabled for this context (used by kernels that need to know
+ // if graphs are in use without having access to the specific graph key)
+ bool any_cuda_graph_enabled() const {
+ for (const auto & [key, graph] : cuda_graphs) {
+ if (graph && graph->is_enabled()) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ // Check if any CUDA graph has an instance for this context
+ bool any_cuda_graph_has_instance() const {
+ for (const auto & [key, graph] : cuda_graphs) {
+ if (graph && graph->instance != nullptr) {
+ return true;
+ }
+ }
+ return false;
+ }
+#endif // USE_CUDA_GRAPH
+
+ explicit ggml_backend_cuda_context(int device) :
+ device(device),
+ name(GGML_CUDA_NAME + std::to_string(device)) {
+ }
+
+ ggml_cuda_stream_context concurrent_stream_context;
+
+ ~ggml_backend_cuda_context();
+
+ cudaStream_t stream(int device, int stream) {
+ if (streams[device][stream] == nullptr) {
+ ggml_cuda_set_device(device);
+ CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
+ }
+ return streams[device][stream];
+ }
+
+ cudaStream_t stream() { return stream(device, curr_stream_no); }
+
+ ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }
+
+ cublasHandle_t cublas_handle(int device) {
+ if (cublas_handles[device] == nullptr) {
+ ggml_cuda_set_device(device);
+ CUBLAS_CHECK(cublasCreate(&cublas_handles[device]));
+ CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH));
+ }
+ return cublas_handles[device];
+ }
+
+ cublasHandle_t cublas_handle() {
+ return cublas_handle(device);
+ }
+
+ // pool
+ std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
+
+ static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
+
+ ggml_cuda_pool & pool(int device) {
+ if (pools[device][curr_stream_no] == nullptr) {
+ pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
+ }
+ return *pools[device][curr_stream_no];
+ }
+
+ ggml_cuda_pool & pool() {
+ return pool(device);
+ }
+};
+
+struct ggml_cuda_mm_fusion_args_host {
+ const ggml_tensor * x_bias = nullptr;
+ const ggml_tensor * gate = nullptr;
+ const ggml_tensor * gate_bias = nullptr;
+ ggml_glu_op glu_op;
+};
+struct ggml_cuda_mm_fusion_args_device {
+ const void * x_bias = nullptr;
+ const void * gate = nullptr;
+ const void * gate_bias = nullptr;
+ ggml_glu_op glu_op;
+};
diff --git a/llama.cpp/ggml/src/ggml-cuda/concat.cu b/llama.cpp/ggml/src/ggml-cuda/concat.cu
new file mode 100644
index 0000000..e9ffd27
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/concat.cu
@@ -0,0 +1,221 @@
+#include "concat.cuh"
+
+// contiguous kernels
+static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) {
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+
+ int offset_dst =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * gridDim.y;
+
+ if (nidx < ne00) { // src0
+ int offset_src =
+ nidx +
+ blockIdx.y * ne00 +
+ blockIdx.z * ne00 * gridDim.y;
+ dst[offset_dst] = x[offset_src];
+ } else {
+ int offset_src =
+ (nidx - ne00) +
+ blockIdx.y * (ne0 - ne00) +
+ blockIdx.z * (ne0 - ne00) * gridDim.y;
+ dst[offset_dst] = y[offset_src];
+ }
+}
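+
+// Worked example for concat_f32_dim0 (illustrative numbers): with ne00 = 3 and ne0 = 5, the thread
+// at nidx = 2, blockIdx.y = 1, blockIdx.z = 0 writes dst[2 + 1*5] = dst[7] from x[2 + 1*3] = x[5];
+// the thread at nidx = 4 falls past ne00 and writes dst[9] from y[(4-3) + 1*(5-3)] = y[3].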
+
+static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) {
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+
+ int offset_dst =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * gridDim.y;
+
+ if (blockIdx.y < (unsigned)ne01) { // src0
+ int offset_src =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * ne01;
+ dst[offset_dst] = x[offset_src];
+ } else {
+ int offset_src =
+ nidx +
+ (blockIdx.y - ne01) * ne0 +
+ blockIdx.z * ne0 * (gridDim.y - ne01);
+ dst[offset_dst] = y[offset_src];
+ }
+}
+
+static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) {
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+
+ int offset_dst =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * gridDim.y;
+
+ if (blockIdx.z < (unsigned)ne02) { // src0
+ int offset_src =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * gridDim.y;
+ dst[offset_dst] = x[offset_src];
+ } else {
+ int offset_src =
+ nidx +
+ blockIdx.y * ne0 +
+ (blockIdx.z - ne02) * ne0 * gridDim.y;
+ dst[offset_dst] = y[offset_src];
+ }
+}
+
+static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) {
+ int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
+ dim3 gridDim(num_blocks, ne1, ne2);
+ if (dim == 0) {
+ concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00);
+ return;
+ }
+ if (dim == 1) {
+ concat_f32_dim1<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne01);
+ return;
+ }
+ concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
+}
+
+// non-contiguous kernel (slow)
+template <int dim>
+static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE)
+ concat_f32_non_cont(
+ const char * src0,
+ const char * src1,
+ char * dst,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne03,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ uint64_t nb03,
+ int64_t /*ne10*/,
+ int64_t /*ne11*/,
+ int64_t /*ne12*/,
+ int64_t /*ne13*/,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ uint64_t nb13,
+ int64_t ne0,
+ int64_t /*ne1*/,
+ int64_t /*ne2*/,
+ int64_t /*ne3*/,
+ uint64_t nb0,
+ uint64_t nb1,
+ uint64_t nb2,
+ uint64_t nb3){
+ static_assert(dim >= 0 && dim <= 3, "dim must be in [0, 3]");
+
+ const int64_t i3 = blockIdx.z;
+ const int64_t i2 = blockIdx.y;
+ const int64_t i1 = blockIdx.x;
+
+ const float * x;
+
+ for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
+ } else {
+ if constexpr (dim == 0) {
+ x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10);
+ } else if constexpr (dim == 1) {
+ x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10);
+ } else if constexpr (dim == 2) {
+ x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10);
+ } else if constexpr (dim == 3) {
+ x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10);
+ }
+ }
+
+ float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ *y = *x;
+ }
+}
+
+
+void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ cudaStream_t stream = ctx.stream();
+
+ const int32_t dim = ((int32_t *) dst->op_params)[0];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+ const float * src0_d = (const float *)src0->data;
+ const float * src1_d = (const float *)src1->data;
+
+ float * dst_d = (float *)dst->data;
+
+ if (dim != 3) {
+ for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+ concat_f32_cuda(
+ src0_d + i3 * (src0->nb[3] / 4),
+ src1_d + i3 * (src1->nb[3] / 4),
+ dst_d + i3 * ( dst->nb[3] / 4),
+ src0->ne[0], src0->ne[1], src0->ne[2],
+ dst->ne[0], dst->ne[1], dst->ne[2], dim, stream);
+ }
+ } else {
+ const size_t size0 = ggml_nbytes(src0);
+ const size_t size1 = ggml_nbytes(src1);
+
+ CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream));
+ CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream));
+ }
+ } else {
+ dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
+ auto launch_kernel = [&](auto dim) {
+ concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
+ (const char *) src0->data, (const char *) src1->data, (char *) dst->data,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+ src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3],
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3]);
+ };
+ switch (dim) {
+ case 0:
+ launch_kernel(std::integral_constant<int, 0>{});
+ break;
+ case 1:
+ launch_kernel(std::integral_constant<int, 1>{});
+ break;
+ case 2:
+ launch_kernel(std::integral_constant<int, 2>{});
+ break;
+ case 3:
+ launch_kernel(std::integral_constant<int, 3>{});
+ break;
+ default:
+ GGML_ABORT("Invalid dim: %d", dim);
+ break;
+ }
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/concat.cuh b/llama.cpp/ggml/src/ggml-cuda/concat.cuh
new file mode 100644
index 0000000..aa506a0
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/concat.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_CONCAT_BLOCK_SIZE 256
+
+void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu b/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu
new file mode 100644
index 0000000..8418ba6
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cu
@@ -0,0 +1,86 @@
+#include "conv-transpose-1d.cuh"
+
+static __global__ void conv_transpose_1d_kernel(
+ const int s0, const int p0, const int d0, const int output_size,
+ const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
+ const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
+ const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
+ const float * src0, const float * src1, float * dst) {
+ int global_index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (global_index >= output_size) {
+ return;
+ }
+
+ int out_index = global_index / dst_ne0;
+
+ float accumulator = 0;
+
+ for (int c = 0; c < src0_ne2; c++) {
+ int idx = global_index % dst_ne0;
+
+ int kernel_offset = (src0_ne0 * src0_ne1 * c) + (out_index * src0_ne0);
+ int input_offset = src1_ne0 * c;
+
+ for (int i = 0; i < src1_ne0; i++) {
+ if (!(idx >= i*s0 && idx < i*s0 + src0_ne0)) {
+ continue;
+ }
+ int weight_idx = idx - i*s0;
+
+ float kernel_weight = src0[kernel_offset + weight_idx];
+ float input_value = src1[input_offset+i];
+
+ accumulator += kernel_weight * input_value;
+ }
+ }
+ dst[global_index] = accumulator;
+ GGML_UNUSED_VARS(p0, d0, src0_ne3, src1_ne3, dst_ne3, src1_ne1, dst_ne1, src1_ne2, dst_ne2);
+}
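+
+// Worked example (illustrative numbers): with kernel width src0_ne0 = 3 and stride s0 = 2, output
+// position idx = 4 receives contributions from input i = 1 (weight_idx = 2) and i = 2
+// (weight_idx = 0); i = 0 is skipped because 4 >= 0*2 + 3.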
+
+static void conv_transpose_1d_f32_f32_cuda(
+ const int s0, const int p0, const int d0, const int output_size,
+ const int src0_ne0, const int src0_ne1, const int src0_ne2, const int src0_ne3,
+ const int src1_ne0, const int src1_ne1, const int src1_ne2, const int src1_ne3,
+ const int dst_ne0, const int dst_ne1, const int dst_ne2, const int dst_ne3,
+ const float * src0, const float * src1, float * dst,
+ cudaStream_t stream) {
+
+ const int num_blocks = (output_size + CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE;
+ conv_transpose_1d_kernel<<<num_blocks,CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE, 0, stream>>>(
+ s0,p0,d0,output_size,
+ src0_ne0, src0_ne1, src0_ne2, src0_ne3,
+ src1_ne0, src1_ne1, src1_ne2, src1_ne3,
+ dst_ne0, dst_ne1, dst_ne2, dst_ne3,
+ src0,src1, dst);
+}
+
+void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+
+ const ggml_tensor * src1 = dst->src[1];
+ const float * src1_d = (const float *)src1->data;
+
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+
+ const int32_t * opts = (const int32_t *)dst->op_params;
+
+ const int s0 = opts[0];
+    const int p0 = 0;  // opts[3]
+    const int d0 = 1;  // opts[4]
+
+ const int64_t output_size = ggml_nelements(dst);
+
+ conv_transpose_1d_f32_f32_cuda(s0, p0, d0, output_size,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3],
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+ src0_d, src1_d, dst_d, stream);
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cuh b/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cuh
new file mode 100644
index 0000000..6c2cf66
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/conv-transpose-1d.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_CONV_TRANPOSE_1D_BLOCK_SIZE 256
+
+void ggml_cuda_op_conv_transpose_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu b/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu
new file mode 100644
index 0000000..7583233
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cu
@@ -0,0 +1,161 @@
+#include "conv2d-dw.cuh"
+
+struct conv_params {
+ int in_w, in_h;
+ int out_w, out_h;
+ int kernel_w, kernel_h;
+ int stride_x, stride_y;
+ int padding_x, padding_y;
+ int dilation_x, dilation_y;
+ int channels, batches;
+};
+
+struct kernel_bounds {
+ int y_min, y_max;
+ int x_min, x_max;
+};
+
+__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int out_x, int out_y, const conv_params & params) {
+ kernel_bounds bounds;
+ bounds.y_min = max(0, (params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
+ bounds.y_max =
+ min(params.kernel_h,
+ (params.in_h + params.padding_y - out_y * params.stride_y + params.dilation_y - 1) / params.dilation_y);
+ bounds.x_min = max(0, (params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
+ bounds.x_max =
+ min(params.kernel_w,
+ (params.in_w + params.padding_x - out_x * params.stride_x + params.dilation_x - 1) / params.dilation_x);
+ return bounds;
+}
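+
+// Worked example (illustrative numbers): out_y = 0, stride_y = 1, dilation_y = 1, padding_y = 1,
+// in_h = 5, kernel_h = 3 gives y_min = 1 and y_max = 3, i.e. taps ky = 1, 2 hitting in_y = 0, 1;
+// the ky = 0 tap is clipped because it would read in_y = -1.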
+
+__device__ __forceinline__ int calculate_input_coord(int out_coord, int kern_coord, int stride, int dilation, int padding) {
+ return out_coord * stride + kern_coord * dilation - padding;
+}
+
+struct whcn_layout {
+ __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
+ return n * (params.channels * params.in_w * params.in_h) + c * params.in_w * params.in_h + y * params.in_w + x;
+ }
+
+ __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
+ return c * params.kernel_h * params.kernel_w + ky * params.kernel_w + kx;
+ }
+
+ __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
+ return n * (params.channels * params.out_w * params.out_h) + c * params.out_w * params.out_h +
+ y * params.out_w + x;
+ }
+
+ __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
+ int & out_x) {
+ out_x = global_idx % params.out_w;
+ out_y = (global_idx / params.out_w) % params.out_h;
+ c = (global_idx / (params.out_w * params.out_h)) % params.channels;
+ n = global_idx / (params.out_w * params.out_h * params.channels);
+ }
+};
+
+struct cwhn_layout {
+ __device__ static int input_index(int n, int c, int y, int x, const conv_params & params) {
+ return n * (params.channels * params.in_w * params.in_h) + (y * params.in_w + x) * params.channels + c;
+ }
+
+ __device__ static int kernel_index(int c, int ky, int kx, const conv_params & params) {
+ return (ky * params.kernel_w + kx) * params.channels + c;
+ }
+
+ __device__ static int output_index(int n, int c, int y, int x, const conv_params & params) {
+ return n * (params.channels * params.out_w * params.out_h) + y * (params.out_w * params.channels) +
+ x * params.channels + c;
+ }
+
+ __device__ static void unpack_indices(int global_idx, const conv_params & params, int & n, int & c, int & out_y,
+ int & out_x) {
+ c = global_idx % params.channels;
+ out_x = (global_idx / params.channels) % params.out_w;
+ out_y = (global_idx / (params.channels * params.out_w)) % params.out_h;
+ n = global_idx / (params.channels * params.out_w * params.out_h);
+ }
+};
+
+template <typename T, typename Layout>
+__global__ void conv2d_dw_kernel(const T * __restrict__ input, const T * __restrict__ kernel, T * __restrict__ output,
+ const int in_w, const int in_h, const int out_w, const int out_h,
+ const int kernel_w, const int kernel_h, const int stride_x, const int stride_y,
+ const int padding_x, const int padding_y, const int dilation_x, const int dilation_y,
+ const int channels, const int batches) {
+ const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ const int total_elements = batches * channels * out_h * out_w;
+
+ if (global_idx >= total_elements) {
+ return;
+ }
+
+ conv_params params = { in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x,
+ stride_y, padding_x, padding_y, dilation_x, dilation_y, channels, batches };
+
+ int batch_idx, channel_idx, out_y_idx, out_x_idx;
+ Layout::unpack_indices(global_idx, params, batch_idx, channel_idx, out_y_idx, out_x_idx);
+
+ T accumulator = 0;
+ kernel_bounds bounds = calculate_kernel_bounds(out_x_idx, out_y_idx, params);
+
+ for (int kern_y = bounds.y_min; kern_y < bounds.y_max; ++kern_y) {
+ int in_y_idx = calculate_input_coord(out_y_idx, kern_y, params.stride_y, params.dilation_y, params.padding_y);
+
+ for (int kern_x = bounds.x_min; kern_x < bounds.x_max; ++kern_x) {
+ int in_x_idx = calculate_input_coord(out_x_idx, kern_x, params.stride_x, params.dilation_x, params.padding_x);
+
+ const T input_val = input[Layout::input_index(batch_idx, channel_idx, in_y_idx, in_x_idx, params)];
+ const T kernel_val = kernel[Layout::kernel_index(channel_idx, kern_y, kern_x, params)];
+
+ accumulator += input_val * kernel_val;
+ }
+ }
+
+ output[Layout::output_index(batch_idx, channel_idx, out_y_idx, out_x_idx, params)] = accumulator;
+}
+
+void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * kernel = dst->src[0];
+ const ggml_tensor * input = dst->src[1];
+
+ GGML_ASSERT(kernel->type == GGML_TYPE_F32 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+ const float * w_d = (const float *) kernel->data;
+ const float * x_d = (const float *) input->data;
+ float * y_d = (float *) dst->data;
+
+ const int32_t * p = (const int32_t *) dst->op_params;
+ const int stride_x = p[0];
+ const int stride_y = p[1];
+ const int padding_x = p[2];
+ const int padding_y = p[3];
+ const int dilation_x = p[4];
+ const int dilation_y = p[5];
+
+ const int in_w = input->ne[0];
+ const int in_h = input->ne[1];
+ const int kernel_w = kernel->ne[0];
+ const int kernel_h = kernel->ne[1];
+ const int out_w = dst->ne[0];
+ const int out_h = dst->ne[1];
+ const int channels = dst->ne[2];
+ const int batches = dst->ne[3];
+
+ cudaStream_t st = ctx.stream();
+
+ const int total = batches * channels * out_h * out_w;
+ const int blocks = (total + CUDA_CONV2D_DW_BLOCK_SIZE - 1) / CUDA_CONV2D_DW_BLOCK_SIZE;
+
+ if (ggml_is_contiguous(input)) {
+ conv2d_dw_kernel<float, whcn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
+ x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
+ dilation_x, dilation_y, channels, batches);
+ } else if (ggml_is_contiguous_channels(input)) {
+ conv2d_dw_kernel<float, cwhn_layout><<<blocks, CUDA_CONV2D_DW_BLOCK_SIZE, 0, st>>>(
+ x_d, w_d, y_d, in_w, in_h, out_w, out_h, kernel_w, kernel_h, stride_x, stride_y, padding_x, padding_y,
+ dilation_x, dilation_y, channels, batches);
+ } else {
+ GGML_ABORT("Unsupported memory layout for conv_2d_dw");
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh b/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh
new file mode 100644
index 0000000..b5d5a69
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/conv2d-dw.cuh
@@ -0,0 +1,5 @@
+#pragma once
+#include "common.cuh"
+
+#define CUDA_CONV2D_DW_BLOCK_SIZE 256
+void ggml_cuda_op_conv2d_dw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu b/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu
new file mode 100644
index 0000000..03224e4
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cu
@@ -0,0 +1,91 @@
+#include <algorithm>
+
+#include "conv2d-transpose.cuh"
+#include "ggml.h"
+
+__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel,
+ float * __restrict__ output, const int in_w, const int in_h, const int out_w,
+ const int out_h, const int kernel_w, const int kernel_h, const int stride,
+ const int c_in, const int c_out, const int batches) {
+ const int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+ const int total_elements = out_w * out_h * c_out * batches;
+
+ if (global_idx >= total_elements) {
+ return;
+ }
+
+ const int out_x_idx = global_idx % out_w;
+ const int out_y_idx = (global_idx / out_w) % out_h;
+ const int c_idx = (global_idx / (out_w * out_h)) % c_out;
+ const int n_idx = global_idx / (out_w * out_h * c_out);
+
+ float accumulator = 0;
+ // For each output idx, find the inputs that contribute to it by checking stride alignment and bounds
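+    // Worked example (illustrative numbers): with stride = 2, kernel_w = 3 and out_x_idx = 5, only
+    // kw = 1 contributes (in_x = (5 - 1)/2 = 2, assuming 2 < in_w); kw = 0 and kw = 2 are skipped
+    // because 5 - 0 and 5 - 2 are not multiples of the stride.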
+
+ for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) {
+ for (int kh = 0; kh < kernel_h; ++kh) {
+ int in_y = out_y_idx - kh;
+ if (in_y < 0 || in_y % stride) continue;
+ in_y /= stride;
+ if (in_y >= in_h) continue;
+
+ for (int kw = 0; kw < kernel_w; ++kw) {
+ int in_x = out_x_idx - kw;
+ if (in_x < 0 || in_x % stride) continue;
+ in_x /= stride;
+ if (in_x >= in_w) continue;
+
+                const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + in_w * in_y + in_x;
+                const int kernel_idx =
+                    (kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + kernel_w * kh + kw;
+
+ float input_val = input[input_idx];
+ half kern_val = kernel[kernel_idx];
+
+ accumulator += input_val * (float) kern_val;
+ }
+ }
+ }
+
+    output[(out_w * out_h * c_out) * n_idx + (out_w * out_h) * c_idx + out_w * out_y_idx + out_x_idx] = accumulator;
+}
+
+// input is (W, H, C_in, N), kernel is (W, H, C_out, C_in)
+void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * kernel = dst->src[0];
+ const ggml_tensor * input = dst->src[1];
+
+ GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+
+ const float * input_data = (const float *) input->data;
+ float * output_data = (float *) dst->data;
+ const half * kernel_data = (const half *) kernel->data;
+
+ const int input_w = input->ne[0];
+ const int input_h = input->ne[1];
+ const int output_w = dst->ne[0];
+ const int output_h = dst->ne[1];
+ const int channels_in = input->ne[2];
+ const int channels_out = kernel->ne[2];
+ const int kernel_w = kernel->ne[0];
+ const int kernel_h = kernel->ne[1];
+ const int stride = dst->op_params[0];
+ const int batches = input->ne[3];
+
+ GGML_ASSERT(channels_in == kernel->ne[3]);
+ GGML_ASSERT(stride > 0);
+
+ cudaStream_t st = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(input));
+ GGML_ASSERT(ggml_is_contiguous(kernel));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ const int total = (output_w * output_h * channels_out * batches);
+ const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE;
+
+ conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>(
+ input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride,
+ channels_in, channels_out, batches);
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh b/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh
new file mode 100644
index 0000000..c9430b2
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/conv2d-transpose.cuh
@@ -0,0 +1,4 @@
+#include "common.cuh"
+
+#define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256
+void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/conv2d.cu b/llama.cpp/ggml/src/ggml-cuda/conv2d.cu
new file mode 100644
index 0000000..142dd66
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/conv2d.cu
@@ -0,0 +1,166 @@
+#include "conv2d.cuh"
+#include "convert.cuh"
+
+struct conv_params {
+ const int64_t IW, IH;
+ const int64_t OW, OH;
+ const int64_t KW, KH;
+ const int64_t ST_X, ST_Y;
+ const int64_t PD_X, PD_Y;
+ const int64_t DL_X, DL_Y;
+ const int64_t IC, OC;
+ const int64_t B;
+ const int64_t TOTAL;
+};
+
+struct kernel_bounds {
+ int64_t y_min, y_max;
+ int64_t x_min, x_max;
+};
+
+__device__ __forceinline__ int64_t max64(int64_t a, int64_t b) {
+ return (a > b) ? a : b;
+}
+
+__device__ __forceinline__ int64_t min64(int64_t a, int64_t b) {
+ return (a < b) ? a : b;
+}
+
+__device__ __forceinline__ kernel_bounds calculate_kernel_bounds(int64_t out_x, int64_t out_y, const conv_params & P) {
+ kernel_bounds bounds;
+ bounds.y_min = max64(0, (P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
+ bounds.y_max = min64(P.KH, (P.IH + P.PD_Y - out_y * P.ST_Y + P.DL_Y - 1) / P.DL_Y);
+ bounds.x_min = max64(0, (P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
+ bounds.x_max = min64(P.KW, (P.IW + P.PD_X - out_x * P.ST_X + P.DL_X - 1) / P.DL_X);
+ return bounds;
+}
+
+__device__ __forceinline__ int calculate_input_coord(int64_t out_coord,
+ int64_t kern_coord,
+ int64_t stride,
+ int64_t dilation,
+ int64_t padding) {
+ return out_coord * stride + kern_coord * dilation - padding;
+}
+
+struct whcn_layout {
+ __device__ static int64_t input_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+ return n * (P.IC * P.IW * P.IH) + c * P.IW * P.IH + y * P.IW + x;
+ }
+
+ __device__ static int64_t kernel_index(int64_t c_out, int64_t c_in, int64_t ky, int64_t kx, const conv_params & P) {
+ return c_out * (P.IC * P.KH * P.KW) + c_in * (P.KH * P.KW) + ky * P.KW + kx;
+ }
+
+ __device__ static int64_t output_index(int64_t n, int64_t c, int64_t y, int64_t x, const conv_params & P) {
+ return n * (P.OC * P.OW * P.OH) + c * P.OW * P.OH + y * P.OW + x;
+ }
+
+ __device__ static void unpack_indices(int64_t global_idx,
+ const conv_params & P,
+ int64_t & n,
+ int64_t & c,
+ int64_t & out_y,
+ int64_t & out_x) {
+ out_x = global_idx % P.OW;
+ out_y = (global_idx / P.OW) % P.OH;
+ c = (global_idx / (P.OW * P.OH)) % P.OC;
+ n = global_idx / (P.OW * P.OH * P.OC);
+ }
+};
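+
+// Worked example for whcn_layout::unpack_indices (illustrative numbers): with OW = 4, OH = 3,
+// OC = 2, global_idx = 17 unpacks to out_x = 17 % 4 = 1, out_y = (17/4) % 3 = 1,
+// c = (17/12) % 2 = 1 and n = 17/24 = 0.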
+
+template <typename T, typename Layout>
+static __global__ void conv2d_kernel(const float * __restrict__ input,
+ const T * __restrict__ kernel,
+ float * __restrict__ output,
+ const conv_params P) {
+ const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (global_idx >= P.TOTAL) {
+ return;
+ }
+
+ int64_t n, c_out, out_y, out_x;
+ Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
+
+ float acc = 0.0f;
+
+ for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
+ kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
+
+ for (int64_t ky = bounds.y_min; ky < bounds.y_max; ++ky) {
+ const int64_t in_y = calculate_input_coord(out_y, ky, P.ST_Y, P.DL_Y, P.PD_Y);
+
+ for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
+ const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
+
+ const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
+ const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
+ acc += (input_val * ggml_cuda_cast<float>(kernel_val));
+ }
+ }
+ }
+
+ // [N, OC, OH, OW]
+ output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc;
+}
+
+template <typename T>
+static void conv2d_cuda(const float * X_D, const T * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+ const int blocks = (P.TOTAL + CUDA_CONV2D_BLOCK_SIZE - 1) / CUDA_CONV2D_BLOCK_SIZE;
+ conv2d_kernel<T, whcn_layout><<<blocks, CUDA_CONV2D_BLOCK_SIZE, 0, st>>>(X_D, K_D, Y_D, P);
+}
+
+static void conv2d_cuda_f16(const float * X_D, const half * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+ conv2d_cuda<half>(X_D, K_D, Y_D, P, st);
+}
+
+static void conv2d_cuda_f32(const float * X_D, const float * K_D, float * Y_D, const conv_params P, cudaStream_t st) {
+ conv2d_cuda<float>(X_D, K_D, Y_D, P, st);
+}
+
+void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * kernel = dst->src[0];
+ const ggml_tensor * input = dst->src[1];
+ float * K_D = (float *) kernel->data;
+ const float * X_D = (const float *) input->data;
+ float * Y_D = (float *) dst->data;
+
+ GGML_ASSERT(ggml_is_contiguous(kernel));
+ GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32);
+
+ // same number of input channels
+ GGML_ASSERT(input->ne[2] == kernel->ne[2]);
+
+ cudaStream_t st = ctx.stream();
+
+ const int32_t * p = (const int32_t *) dst->op_params;
+ const int ST_X = p[0]; // stride_x
+ const int ST_Y = p[1]; // stride_y
+ const int PD_X = p[2]; // padding_x
+ const int PD_Y = p[3]; // padding_y
+ const int DL_X = p[4]; // dilation_x
+ const int DL_Y = p[5]; // dilation_y
+
+ // No cwhn
+ GGML_ASSERT(p[6] == false);
+
+ const int IW = input->ne[0]; // input_w
+ const int IH = input->ne[1]; // input_h
+ const int OW = dst->ne[0]; // output_w
+ const int OH = dst->ne[1]; // output_h
+ const int KW = kernel->ne[0]; // kernel_w
+ const int KH = kernel->ne[1]; // kernel_h
+ const int IC = input->ne[2]; // input_channels
+    const int OC = kernel->ne[3];  // output_channels
+ const int B = input->ne[3]; // n_batches
+
+ const int64_t total = B * OC * OH * OW;
+ conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X, PD_Y, DL_X, DL_Y, IC, OC, B, total };
+
+ if (kernel->type == GGML_TYPE_F16) {
+ conv2d_cuda_f16(X_D, (half *) K_D, Y_D, params, st);
+ } else {
+ conv2d_cuda_f32(X_D, K_D, Y_D, params, st);
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh b/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh
new file mode 100644
index 0000000..ce4802c
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/conv2d.cuh
@@ -0,0 +1,5 @@
+#pragma once
+#include "common.cuh"
+
+#define CUDA_CONV2D_BLOCK_SIZE 256
+void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/convert.cu b/llama.cpp/ggml/src/ggml-cuda/convert.cu
new file mode 100644
index 0000000..ba3d4ee
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/convert.cu
@@ -0,0 +1,825 @@
+#include "convert.cuh"
+#include "dequantize.cuh"
+
+#include <cstdint>
+
+#define CUDA_Q8_0_NE_ALIGN 2048
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02,
+ const int64_t s01, const int64_t s02, const int64_t s03) {
+ const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
+
+ if (i00 >= ne00) {
+ return;
+ }
+
+ const int64_t i01 = blockIdx.y;
+ const int64_t i02 = blockIdx.z % ne02;
+ const int64_t i03 = blockIdx.z / ne02;
+
+ const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
+
+ const int64_t ib = ibx0 + i00/qk; // block index
+ const int64_t iqs = (i00%qk)/qr; // quant index
+ const int64_t iybs = i00 - i00%qk; // y block start index
+ const int64_t y_offset = qr == 1 ? 1 : qk/2;
+
+ // dequantize
+ float2 v;
+ dequantize_kernel(vx, ib, iqs, v);
+
+ const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
+ y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
+ y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
+}
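+
+// Worked example (illustrative numbers, q4_0: qk = 32, qr = 2): the thread with i00 = 40 reads
+// block ib = ibx0 + 1 at quant index iqs = (40 % 32)/2 = 4, with iybs = 32 and y_offset = 16,
+// so it writes its two dequantized values to offsets 36 and 52 of the destination row.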
+
+template <bool need_check>
+static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+ constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
+
+ const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
+ const int * x0 = ((int *) vx) + blockIdx.x * nint;
+ half2 * y2 = (half2 *) (y + i0);
+
+ __shared__ int vals[nint];
+
+#pragma unroll
+ for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {
+ if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {
+ break;
+ }
+
+ const int ix = ix0 + threadIdx.x;
+ vals[ix] = x0[ix];
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
+ if (need_check && i0 + iy + 2*threadIdx.x >= k) {
+ return;
+ }
+
+ const half * b0 = ((const half *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);
+ const half d = *b0;
+ const char2 qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];
+
+ y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
+ }
+#else
+ GGML_UNUSED_VARS(vx, y, k);
+ NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
+
+ const int64_t i = blockIdx.x;
+
+ // assume 32 threads
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8;
+ const int64_t ir = tid%8;
+ const int64_t ib = 8*i + ir;
+ if (ib >= nb32) {
+ return;
+ }
+
+ dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+ const block_q4_0 * x = (const block_q4_0 *)vx + ib;
+ const float d = __half2float(x->d);
+ const float dm = -8*d;
+
+ const uint8_t * q = x->qs + 4*il;
+
+ for (int l = 0; l < 4; ++l) {
+ y[l+ 0] = d * (q[l] & 0xF) + dm;
+ y[l+16] = d * (q[l] >> 4) + dm;
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
+
+ const int64_t i = blockIdx.x;
+
+ // assume 32 threads
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8;
+ const int64_t ir = tid%8;
+ const int64_t ib = 8*i + ir;
+ if (ib >= nb32) {
+ return;
+ }
+
+ dst_t * y = yy + 256*i + 32*ir + 4*il;
+
+ const block_q4_1 * x = (const block_q4_1 *)vx + ib;
+ const float2 d = __half22float2(x->dm);
+
+ const uint8_t * q = x->qs + 4*il;
+
+ for (int l = 0; l < 4; ++l) {
+ y[l+ 0] = d.x * (q[l] & 0xF) + d.y;
+ y[l+16] = d.x * (q[l] >> 4) + d.y;
+ }
+}
+
+//================================== k-quants
+
+template<typename dst_t>
+static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_q2_K * x = (const block_q2_K *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t n = tid/32;
+ const int64_t l = tid - 32*n;
+ const int64_t is = 8*n + l/16;
+
+ const uint8_t q = x[i].qs[32*n + l];
+ dst_t * y = yy + i*QK_K + 128*n;
+
+ float dall = __low2half(x[i].dm);
+ float dmin = __high2half(x[i].dm);
+ y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+ y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
+ y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
+ y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_q3_K * x = (const block_q3_K *) vx;
+
+ const int64_t r = threadIdx.x/4;
+ const int64_t tid = r/2;
+ const int64_t is0 = r%2;
+ const int64_t l0 = 16*is0 + 4*(threadIdx.x%4);
+ const int64_t n = tid / 4;
+ const int64_t j = tid - 4*n;
+
+ uint8_t m = 1 << (4*n + j);
+ int64_t is = 8*n + 2*j + is0;
+ int shift = 2*j;
+
+ int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
+ is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
+ is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
+ (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
+ float d_all = x[i].d;
+ float dl = d_all * (us - 32);
+
+ dst_t * y = yy + i*QK_K + 128*n + 32*j;
+ const uint8_t * q = x[i].qs + 32*n;
+ const uint8_t * hm = x[i].hmask;
+
+ for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
+}
+
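+// The 12-byte scales array of q4_K/q5_K packs eight 6-bit (scale, min) pairs: pairs 0-3 use the
+// low 6 bits of bytes 0-3 (scales) and 4-7 (mins); pairs 4-7 take their low nibble from bytes 8-11
+// and their top two bits from the spare high bits of bytes 0-3 (scale) and 4-7 (min).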
+static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
+ if (j < 4) {
+ d = q[j] & 63; m = q[j + 4] & 63;
+ } else {
+ d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
+ m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+ const block_q4_K * x = (const block_q4_K *) vx;
+
+ const int64_t i = blockIdx.x;
+
+ // assume 32 threads
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8;
+ const int64_t ir = tid%8;
+ const int64_t is = 2*il;
+ const int64_t n = 4;
+
+ dst_t * y = yy + i*QK_K + 64*il + n*ir;
+
+ const float dall = __low2half(x[i].dm);
+ const float dmin = __high2half(x[i].dm);
+
+ const uint8_t * q = x[i].qs + 32*il + n*ir;
+
+ uint8_t sc, m;
+ get_scale_min_k4(is + 0, x[i].scales, sc, m);
+ const float d1 = dall * sc; const float m1 = dmin * m;
+ get_scale_min_k4(is + 1, x[i].scales, sc, m);
+ const float d2 = dall * sc; const float m2 = dmin * m;
+ for (int l = 0; l < n; ++l) {
+ y[l + 0] = d1 * (q[l] & 0xF) - m1;
+ y[l +32] = d2 * (q[l] >> 4) - m2;
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+ const block_q5_K * x = (const block_q5_K *) vx;
+
+ const int64_t i = blockIdx.x;
+
+    // assume 64 threads
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/16; // il is in 0...3
+ const int64_t ir = tid%16; // ir is in 0...15
+ const int64_t is = 2*il; // is is in 0...6
+
+ dst_t * y = yy + i*QK_K + 64*il + 2*ir;
+
+ const float dall = __low2half(x[i].dm);
+ const float dmin = __high2half(x[i].dm);
+
+ const uint8_t * ql = x[i].qs + 32*il + 2*ir;
+ const uint8_t * qh = x[i].qh + 2*ir;
+
+ uint8_t sc, m;
+ get_scale_min_k4(is + 0, x[i].scales, sc, m);
+ const float d1 = dall * sc; const float m1 = dmin * m;
+ get_scale_min_k4(is + 1, x[i].scales, sc, m);
+ const float d2 = dall * sc; const float m2 = dmin * m;
+
+ uint8_t hm = 1 << (2*il);
+ y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
+ y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
+ hm <<= 1;
+ y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
+ y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+ const block_q6_K * x = (const block_q6_K *) vx;
+
+ const int64_t i = blockIdx.x;
+
+    // assume 64 threads
+ const int64_t tid = threadIdx.x;
+ const int64_t ip = tid/32; // ip is 0 or 1
+    const int64_t il = tid - 32*ip; // 0...31
+ const int64_t is = 8*ip + il/16;
+
+ dst_t * y = yy + i*QK_K + 128*ip + il;
+
+ const float d = x[i].d;
+
+ const uint8_t * ql = x[i].ql + 64*ip + il;
+ const uint8_t qh = x[i].qh[32*ip + il];
+ const int8_t * sc = x[i].scales + is;
+
+ y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
+ y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
+ y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
+ y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint16_t * q2 = x[i].qs + 4*ib;
+ const uint8_t * aux8 = (const uint8_t *)q2;
+ const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
+ const uint32_t aux32 = q2[2] | (q2[3] << 16);
+ const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+ for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq2_xs * x = (const block_iq2_xs *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint16_t * q2 = x[i].qs + 4*ib;
+ const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
+ const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+ const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
+ for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq2_s * x = (const block_iq2_s *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
+ const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
+ const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
+ for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint8_t * q3 = x[i].qs + 8*ib;
+ const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
+ const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
+ const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
+ const uint32_t aux32 = gas[0] | (gas[1] << 16);
+ const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
+ const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
+ for (int j = 0; j < 4; ++j) {
+ y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+ y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq3_s * x = (const block_iq3_s *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint8_t * qs = x[i].qs + 8*ib;
+ const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
+ const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
+ const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
+ const uint8_t signs = x[i].signs[4*ib + il];
+ for (int j = 0; j < 4; ++j) {
+ y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+ y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq1_s * x = (const block_iq1_s *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
+ const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
+ uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+ grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
+ grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+ grid32[0] &= 0x0f0f0f0f;
+ for (int j = 0; j < 8; ++j) {
+ y[j] = d * (q[j] + delta);
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq1_m * x = (const block_iq1_m *) vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 8*il;
+ const uint16_t * sc = (const uint16_t *)x[i].scales;
+ iq1m_scale_t scale;
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+ const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
+ const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
+ const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
+ uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
+ grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
+ grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
+ grid32[0] &= 0x0f0f0f0f;
+ for (int j = 0; j < 8; ++j) {
+ y[j] = d * (q[j] + delta);
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+ const uint8_t * q4 = x[ib].qs + 4*il;
+ const float d = (float)x[ib].d;
+ for (int j = 0; j < 4; ++j) {
+ y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+ y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+ const int64_t i = blockIdx.x;
+ const block_iq4_xs * x = (const block_iq4_xs *)vx;
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+ const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
+ const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
+ for (int j = 0; j < 4; ++j) {
+ y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
+ y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
+ }
+}
+
+template<typename dst_t>
+static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+ const uint8_t * q4 = x[ib].qs + 4*il;
+ const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);
+ for (int j = 0; j < 4; ++j) {
+ y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
+ y[j+16] = d * kvalues_mxfp4[q4[j] >> 4]*0.5f;
+ }
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_cuda(const void * vx, dst_t * y,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
+ const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
+ dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
+ (vx, y, ne00, ne01, ne02, s01, s02, s03);
+}
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_cont_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
+ dequantize_block_cuda<qk, qr, dequantize_kernel, dst_t>(vx, y, k, 1, 1, 1, k/qk, k/qk, k/qk, stream);
+}
+
+static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
+ if (k % CUDA_Q8_0_NE_ALIGN == 0) {
+ const bool need_check = false;
+ dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
+ } else {
+ const bool need_check = true;
+ dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
+ }
+}
+
+template<typename dst_t>
+static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb32 = k / 32;
+ const int nb = (k + 255) / 256;
+ dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
+}
+
+template<typename dst_t>
+static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb32 = k / 32;
+ const int nb = (k + 255) / 256;
+ dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
+}
+
+template<typename dst_t>
+static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = (k + QK_K - 1) / QK_K;
+ dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = k / QK_K;
+ dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = (k + QK_K - 1) / QK_K;
+ dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template<typename dst_t>
+static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = (k + QK_K - 1) / QK_K;
+ dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
+}
+
+template <typename src_t, typename dst_t>
+static __global__ void convert_unary(
+ const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+ const int64_t s01, const int64_t s02, const int64_t s03) {
+ const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i00 >= ne00) {
+ return;
+ }
+
+ const int64_t i01 = blockIdx.y;
+ const int64_t i02 = blockIdx.z % ne02;
+ const int64_t i03 = blockIdx.z / ne02;
+
+ const src_t * x = (const src_t *) vx;
+
+ const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
+ const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
+ y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
+}
+
+template <typename src_t, typename dst_t>
+static void convert_unary_cuda(const void * vx, dst_t * y,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
+ const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
+ convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
+ (vx, y, ne00, ne01, ne02, s01, s02, s03);
+}
+
+template <typename src_t, typename dst_t>
+static void convert_unary_cont_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ convert_unary_cuda<src_t>(vx, y, k, 1, 1, 1, k, k, k, stream);
+}
+
+to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_F32:
+ return convert_unary_cont_cuda<float>;
+ case GGML_TYPE_F16:
+ return convert_unary_cont_cuda<half>;
+ default:
+ return nullptr;
+ }
+}
+
+to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ return dequantize_row_q4_0_cuda;
+ case GGML_TYPE_Q4_1:
+ return dequantize_row_q4_1_cuda;
+ case GGML_TYPE_Q5_0:
+ return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+ case GGML_TYPE_Q5_1:
+ return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+ case GGML_TYPE_Q8_0:
+ if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
+ return dequantize_block_q8_0_f16_cuda;
+ }
+ return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+ case GGML_TYPE_Q2_K:
+ return dequantize_row_q2_K_cuda;
+ case GGML_TYPE_Q3_K:
+ return dequantize_row_q3_K_cuda;
+ case GGML_TYPE_Q4_K:
+ return dequantize_row_q4_K_cuda;
+ case GGML_TYPE_Q5_K:
+ return dequantize_row_q5_K_cuda;
+ case GGML_TYPE_Q6_K:
+ return dequantize_row_q6_K_cuda;
+ case GGML_TYPE_IQ2_XXS:
+ return dequantize_row_iq2_xxs_cuda;
+ case GGML_TYPE_IQ2_XS:
+ return dequantize_row_iq2_xs_cuda;
+ case GGML_TYPE_IQ2_S:
+ return dequantize_row_iq2_s_cuda;
+ case GGML_TYPE_IQ3_XXS:
+ return dequantize_row_iq3_xxs_cuda;
+ case GGML_TYPE_IQ1_S:
+ return dequantize_row_iq1_s_cuda;
+ case GGML_TYPE_IQ1_M:
+ return dequantize_row_iq1_m_cuda;
+ case GGML_TYPE_IQ4_NL:
+ return dequantize_row_iq4_nl_cuda;
+ case GGML_TYPE_IQ4_XS:
+ return dequantize_row_iq4_xs_cuda;
+ case GGML_TYPE_IQ3_S:
+ return dequantize_row_iq3_s_cuda;
+ case GGML_TYPE_MXFP4:
+ return dequantize_row_mxfp4_cuda;
+ case GGML_TYPE_F32:
+ return convert_unary_cont_cuda<float>;
+ case GGML_TYPE_BF16:
+ return convert_unary_cont_cuda<nv_bfloat16>;
+ default:
+ return nullptr;
+ }
+}
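+
+// Illustrative call site (a minimal sketch; src, dst_f16 and stream are placeholder names): the
+// returned pointer follows the to_t_cuda_t signature from convert.cuh and is null for unsupported
+// types, e.g.
+//   const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src->type);
+//   if (to_fp16 != nullptr) {
+//       to_fp16(src->data, dst_f16, ggml_nelements(src), stream);
+//   }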
+
+to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ return dequantize_row_q4_0_cuda;
+ case GGML_TYPE_Q4_1:
+ return dequantize_row_q4_1_cuda;
+ case GGML_TYPE_Q5_0:
+ return dequantize_block_cont_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+ case GGML_TYPE_Q5_1:
+ return dequantize_block_cont_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+ case GGML_TYPE_Q8_0:
+ return dequantize_block_cont_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+ case GGML_TYPE_Q2_K:
+ return dequantize_row_q2_K_cuda;
+ case GGML_TYPE_Q3_K:
+ return dequantize_row_q3_K_cuda;
+ case GGML_TYPE_Q4_K:
+ return dequantize_row_q4_K_cuda;
+ case GGML_TYPE_Q5_K:
+ return dequantize_row_q5_K_cuda;
+ case GGML_TYPE_Q6_K:
+ return dequantize_row_q6_K_cuda;
+ case GGML_TYPE_IQ2_XXS:
+ return dequantize_row_iq2_xxs_cuda;
+ case GGML_TYPE_IQ2_XS:
+ return dequantize_row_iq2_xs_cuda;
+ case GGML_TYPE_IQ2_S:
+ return dequantize_row_iq2_s_cuda;
+ case GGML_TYPE_IQ3_XXS:
+ return dequantize_row_iq3_xxs_cuda;
+ case GGML_TYPE_IQ1_S:
+ return dequantize_row_iq1_s_cuda;
+ case GGML_TYPE_IQ1_M:
+ return dequantize_row_iq1_m_cuda;
+ case GGML_TYPE_IQ4_NL:
+ return dequantize_row_iq4_nl_cuda;
+ case GGML_TYPE_IQ4_XS:
+ return dequantize_row_iq4_xs_cuda;
+ case GGML_TYPE_IQ3_S:
+ return dequantize_row_iq3_s_cuda;
+ case GGML_TYPE_MXFP4:
+ return dequantize_row_mxfp4_cuda;
+ case GGML_TYPE_F16:
+ return convert_unary_cont_cuda<half>;
+ case GGML_TYPE_BF16:
+ return convert_unary_cont_cuda<nv_bfloat16>;
+ default:
+ return nullptr;
+ }
+}
+
+to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_F32:
+ return convert_unary_cuda<float>;
+ case GGML_TYPE_Q4_0:
+ return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+ case GGML_TYPE_Q4_1:
+ return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+ case GGML_TYPE_Q5_0:
+ return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+ case GGML_TYPE_Q5_1:
+ return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+ case GGML_TYPE_Q8_0:
+ return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+ case GGML_TYPE_BF16:
+ return convert_unary_cuda<nv_bfloat16>;
+ default:
+ return nullptr;
+ }
+}
+
+to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_F32:
+ return convert_unary_cuda<float, nv_bfloat16>;
+ case GGML_TYPE_Q4_0:
+ return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+ case GGML_TYPE_Q4_1:
+ return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+ case GGML_TYPE_Q5_0:
+ return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+ case GGML_TYPE_Q5_1:
+ return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+ case GGML_TYPE_Q8_0:
+ return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+ case GGML_TYPE_F16:
+ return convert_unary_cuda<half, nv_bfloat16>;
+ default:
+ return nullptr;
+ }
+}
+
+to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_F16:
+ return convert_unary_cuda<half, float>;
+ case GGML_TYPE_Q4_0:
+ return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
+ case GGML_TYPE_Q4_1:
+ return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
+ case GGML_TYPE_Q5_0:
+ return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
+ case GGML_TYPE_Q5_1:
+ return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
+ case GGML_TYPE_Q8_0:
+ return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
+ case GGML_TYPE_BF16:
+ return convert_unary_cuda<nv_bfloat16, float>;
+ default:
+ return nullptr;
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/convert.cuh b/llama.cpp/ggml/src/ggml-cuda/convert.cuh
new file mode 100644
index 0000000..09f9a33
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/convert.cuh
@@ -0,0 +1,56 @@
+#pragma once
+#include "common.cuh"
+
+#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
+
+template<typename T>
+using to_t_cuda_t = void (*)(const void * x, T * y, int64_t k, cudaStream_t stream);
+
+typedef to_t_cuda_t<float> to_fp32_cuda_t;
+typedef to_t_cuda_t<half> to_fp16_cuda_t;
+typedef to_t_cuda_t<nv_bfloat16> to_bf16_cuda_t;
+
+to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);
+
+to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type);
+
+to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);
+
+// TODO more general support for non-contiguous inputs
+
+template<typename T>
+using to_t_nc_cuda_t = void (*)(const void * x, T * y,
+ int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
+ int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
+
+typedef to_t_nc_cuda_t<float> to_fp32_nc_cuda_t;
+typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
+typedef to_t_nc_cuda_t<nv_bfloat16> to_bf16_nc_cuda_t;
+
+to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type);
+to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
+to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type);
+
+template<typename dst_t, typename src_t>
+ __host__ __device__ inline dst_t ggml_cuda_cast(src_t x) {
+ if constexpr (std::is_same_v<dst_t, src_t>) {
+ return x;
+ } else if constexpr(std::is_same_v<dst_t, nv_bfloat16>) {
+ return __float2bfloat16(float(x));
+ } else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
+ return __bfloat162float(x);
+ } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
+ return __float22half2_rn(x);
+ } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
+ // bypass compile error on cuda 12.0.1
+#ifdef GGML_USE_HIP
+ return __float22bfloat162_rn(x);
+#else
+ return {x.x, x.y};
+#endif // GGML_USE_HIP
+ } else if constexpr(std::is_same_v<dst_t, int32_t>) {
+ return int32_t(x);
+ } else {
+ return float(x);
+ }
+}
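Note (not part of the patch): convert.cuh only declares function-pointer typedefs and getter functions; the getters return nullptr for unsupported types, so callers are expected to check the pointer before launching. A minimal, hypothetical host-side usage sketch, assuming the declarations above and a caller-chosen wrapper name:

// Hypothetical wrapper; names and arguments are illustrative only.
static void convert_to_f16_sketch(ggml_type type, const void * src, half * dst,
                                  int64_t nelements, cudaStream_t stream) {
    const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(type);
    GGML_ASSERT(to_fp16 != nullptr && "unsupported source type");
    to_fp16(src, dst, nelements, stream);  // enqueue the dequantize/convert kernel
}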
diff --git a/llama.cpp/ggml/src/ggml-cuda/count-equal.cu b/llama.cpp/ggml/src/ggml-cuda/count-equal.cu
new file mode 100644
index 0000000..0889811
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/count-equal.cu
@@ -0,0 +1,64 @@
+#include "common.cuh"
+#include "count-equal.cuh"
+
+#include <cstdint>
+
+template <typename T>
+static __global__ void count_equal(const T * __restrict__ x, const T * __restrict__ y, int64_t * __restrict__ dst, const int64_t dk, const int64_t k) {
+ const int64_t i0 = (int64_t) blockIdx.x*dk;
+ const int64_t i1 = min(i0 + dk, k);
+
+ int nequal = 0;
+
+ for (int64_t i = i0 + threadIdx.x; i < i1; i += WARP_SIZE) {
+ const T xi = x[i];
+ const T yi = y[i];
+ nequal += xi == yi;
+ }
+
+ nequal = warp_reduce_sum(nequal);
+
+ if (threadIdx.x != 0) {
+ return;
+ }
+
+ atomicAdd((int *) dst, nequal);
+}
+
+void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src0->type == src1->type);
+ GGML_ASSERT( dst->type == GGML_TYPE_I64);
+
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ int64_t * dst_d = (int64_t *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
+ const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+
+ const int64_t ne = ggml_nelements(src0);
+ GGML_ASSERT(ne < (1 << 30) && "atomicAdd implementation only supports int");
+ const int64_t dne = GGML_PAD((ne + 4*nsm - 1) / (4*nsm), CUDA_COUNT_EQUAL_CHUNK_SIZE);
+
+ CUDA_CHECK(cudaMemsetAsync(dst_d, 0, ggml_nbytes(dst), stream));
+
+ const dim3 blocks_dim(WARP_SIZE, 1, 1);
+ const dim3 blocks_num(std::min((int64_t)4*nsm, (ne + CUDA_COUNT_EQUAL_CHUNK_SIZE - 1)/CUDA_COUNT_EQUAL_CHUNK_SIZE), 1, 1);
+
+ switch (src0->type) {
+ case GGML_TYPE_I32: {
+ const int * src0_d = (const int *) src0->data;
+ const int * src1_d = (const int *) src1->data;
+ count_equal<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_d, dne, ne);
+ } break;
+ default:
+ GGML_ASSERT(false);
+ break;
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/count-equal.cuh b/llama.cpp/ggml/src/ggml-cuda/count-equal.cuh
new file mode 100644
index 0000000..8467da7
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/count-equal.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_COUNT_EQUAL_CHUNK_SIZE 128
+
+void ggml_cuda_count_equal(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/cp-async.cuh b/llama.cpp/ggml/src/ggml-cuda/cp-async.cuh
new file mode 100644
index 0000000..63d0c48
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/cp-async.cuh
@@ -0,0 +1,57 @@
+// Simplified API for asynchronous data loading.
+
+#include "common.cuh"
+
+
+static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
+#ifdef CP_ASYNC_AVAILABLE
+ return __cvta_generic_to_shared(generic_ptr);
+#else
+ GGML_UNUSED(generic_ptr);
+ NO_DEVICE_CODE;
+ return 0;
+#endif // CP_ASYNC_AVAILABLE
+}
+
+// Copies data from global to shared memory, cg == cache global.
+// Both the src and dst pointers must be aligned to 16 bytes.
+// Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
+// Generic pointers can be converted to 32 bit shared memory pointers using __cvta_generic_to_shared.
+// Only the 16-byte copy is exposed because 4 and 8 byte copies did not yield performance improvements.
+template <int preload>
+static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, const void * src) {
+ static_assert(preload == 0 || preload == 64 || preload == 128 || preload == 256, "bad preload");
+#ifdef CP_ASYNC_AVAILABLE
+#if CUDART_VERSION >= 11040
+ if (preload == 256) {
+ asm volatile("cp.async.cg.shared.global.L2::256B [%0], [%1], 16;"
+ : : "r"(dst), "l"(src));
+ } else if (preload == 128) {
+ asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], 16;"
+ : : "r"(dst), "l"(src));
+ } else if (preload == 64) {
+ asm volatile("cp.async.cg.shared.global.L2::64B [%0], [%1], 16;"
+ : : "r"(dst), "l"(src));
+ } else
+#endif // CUDART_VERSION >= 11040
+ {
+ asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
+ : : "r"(dst), "l"(src));
+ }
+#else
+ GGML_UNUSED(dst);
+ GGML_UNUSED(src);
+ NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
+
+// Makes each thread wait until its asynchronous data copies are done.
+// This does NOT provide any additional synchronization.
+// In particular, when copying data with multiple warps a call to __syncthreads will be needed.
+static __device__ __forceinline__ void cp_async_wait_all() {
+#ifdef CP_ASYNC_AVAILABLE
+ asm volatile("cp.async.wait_all;");
+#else
+ NO_DEVICE_CODE;
+#endif // CP_ASYNC_AVAILABLE
+}
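Note (not part of the patch): the comments above describe the intended call sequence — convert the shared-memory destination with ggml_cuda_cvta_generic_to_shared, issue 16-byte cp.async copies, then wait and barrier across warps. A hypothetical device-side sketch under those assumptions (tile size and addressing are illustrative, pointers assumed 16-byte aligned):

// Illustrative kernel only; not part of the ggml sources.
__global__ void cp_async_usage_sketch(const char * __restrict__ src_global) {
    __shared__ char tile[4096];
    const unsigned int dst_shared = ggml_cuda_cvta_generic_to_shared(tile);
    // each thread copies one 16-byte chunk per iteration
    for (int off = 16*threadIdx.x; off < (int) sizeof(tile); off += 16*blockDim.x) {
        cp_async_cg_16<0>(dst_shared + off, src_global + off);
    }
    cp_async_wait_all();   // wait for this thread's own copies
    __syncthreads();       // copies issued by other warps still require a barrier
}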
diff --git a/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh b/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh
new file mode 100644
index 0000000..7697c29
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/cpy-utils.cuh
@@ -0,0 +1,217 @@
+#pragma once
+
+#include "ggml-common.h"
+#include "convert.cuh"
+
+static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
+ if (x <= val[0]) return 0;
+ if (x >= val[n-1]) return n-1;
+ int ml = 0, mu = n-1;
+ while (mu-ml > 1) {
+ int mav = (ml+mu)/2;
+ if (x < val[mav]) mu = mav; else ml = mav;
+ }
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
+static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) {
+ float amax = 0.0f;
+ float vmax = 0.0f;
+
+ for (int j = 0; j < QK4_0; ++j) {
+ const float v = x[j];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ vmax = v;
+ }
+ }
+
+ const float d = vmax / -8;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y->d = d;
+
+ for (int j = 0; j < QK4_0/2; ++j) {
+ const float x0 = x[0 + j]*id;
+ const float x1 = x[QK4_0/2 + j]*id;
+
+ const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
+ const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
+
+ y->qs[j] = xi0;
+ y->qs[j] |= xi1 << 4;
+ }
+}
+
+static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) {
+ float vmin = FLT_MAX;
+ float vmax = -FLT_MAX;
+
+ for (int j = 0; j < QK4_1; ++j) {
+ const float v = x[j];
+ if (v < vmin) vmin = v;
+ if (v > vmax) vmax = v;
+ }
+
+ const float d = (vmax - vmin) / ((1 << 4) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y->dm.x = d;
+ y->dm.y = vmin;
+
+ for (int j = 0; j < QK4_1/2; ++j) {
+ const float x0 = (x[0 + j] - vmin)*id;
+ const float x1 = (x[QK4_1/2 + j] - vmin)*id;
+
+ const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
+ const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
+
+ y->qs[j] = xi0;
+ y->qs[j] |= xi1 << 4;
+ }
+}
+
+static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) {
+ float amax = 0.0f;
+ float vmax = 0.0f;
+
+ for (int j = 0; j < QK5_0; ++j) {
+ const float v = x[j];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ vmax = v;
+ }
+ }
+
+ const float d = vmax / -16;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y->d = d;
+
+ uint32_t qh = 0;
+ for (int j = 0; j < QK5_0/2; ++j) {
+ const float x0 = x[0 + j]*id;
+ const float x1 = x[QK5_0/2 + j]*id;
+
+ const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
+ const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
+
+ y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+ }
+ memcpy(y->qh, &qh, sizeof(qh));
+}
+
+static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) {
+ float min = x[0];
+ float max = x[0];
+
+ for (int j = 1; j < QK5_1; ++j) {
+ const float v = x[j];
+ min = v < min ? v : min;
+ max = v > max ? v : max;
+ }
+
+ const float d = (max - min) / 31;
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y->dm.x = d;
+ y->dm.y = min;
+
+ uint32_t qh = 0;
+ for (int j = 0; j < QK5_1/2; ++j) {
+ const float x0 = (x[0 + j] - min)*id;
+ const float x1 = (x[QK5_1/2 + j] - min)*id;
+
+ const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+ const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+ y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+ qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+ qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+ }
+ memcpy(y->qh, &qh, sizeof(qh));
+}
+
+static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) {
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; j++) {
+ const float v = x[j];
+ amax = fmaxf(amax, fabsf(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ y->d = d;
+
+ for (int j = 0; j < QK8_0; ++j) {
+ const float x0 = x[j]*id;
+ y->qs[j] = roundf(x0);
+ }
+}
+
+static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) {
+ float amax = 0.0f;
+ float vmax = 0.0f;
+
+ for (int j = 0; j < QK4_NL; ++j) {
+ const float v = x[j];
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ vmax = v;
+ }
+ }
+
+ float d = vmax / kvalues_iq4nl[0];
+ const float id = d ? 1.0f/d : 0.0f;
+
+ float sumqx = 0, sumq2 = 0;
+ for (int j = 0; j < QK4_NL/2; ++j) {
+ const float x0 = x[0 + j]*id;
+ const float x1 = x[QK4_NL/2 + j]*id;
+ const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
+ const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
+ y->qs[j] = xi0 | (xi1 << 4);
+ const float v0 = kvalues_iq4nl[xi0];
+ const float v1 = kvalues_iq4nl[xi1];
+ const float w0 = x[0 + j]*x[0 + j];
+ const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j];
+ sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j];
+ sumq2 += w0*v0*v0 + w1*v1*v1;
+ }
+
+ y->d = sumq2 > 0 ? sumqx/sumq2 : d;
+}
+
+// Wrapper functions for cpy.cu compatibility
+static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
+ quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti);
+}
+
+static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
+ quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti);
+}
+
+static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
+ quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti);
+}
+
+static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
+ quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti);
+}
+
+static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
+ quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti);
+}
+
+static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
+ quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
+}
+
+template<typename src_t, typename dst_t>
+static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) {
+ *(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/cpy.cu b/llama.cpp/ggml/src/ggml-cuda/cpy.cu
new file mode 100644
index 0000000..ee84303
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/cpy.cu
@@ -0,0 +1,555 @@
+#include "cpy.cuh"
+#include "dequantize.cuh"
+#include "cpy-utils.cuh"
+#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
+#include "ggml-musa/mudnn.cuh"
+#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
+
+typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
+
+const int CUDA_CPY_TILE_DIM_2D = 32; // 2D tile dimension for transposed blocks
+const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available
+const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows
+
+template <cpy_kernel_t cpy_1>
+static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+ const int64_t nb12, const int64_t nb13) {
+ const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= ne) {
+ return;
+ }
+
+ // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+ // then combine those indices with the corresponding byte offsets to get the total offsets
+ const int64_t i03 = i/(ne00 * ne01 * ne02);
+ const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+ const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+ const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+ const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+ const int64_t i13 = i/(ne10 * ne11 * ne12);
+ const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+ const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+ const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+ const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
+
+ cpy_1(cx + x_offset, cdst + dst_offset);
+}
+
+template <typename T>
+static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int64_t ne,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+ const int64_t nb12, const int64_t nb13) {
+
+ const T* src = reinterpret_cast<const T*>(cx);
+ T* dst = reinterpret_cast<T*>(cdst);
+
+ const int64_t nmat = ne / (ne00 * ne01);
+ const int64_t n = ne00 * ne01;
+
+ const int x = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.x;
+ const int y = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
+ const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset
+ const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y;
+
+ __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1];
+
+#pragma unroll
+ for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) {
+
+ const unsigned int imat = blockIdx.z * CUDA_CPY_BLOCK_NM + i;
+ if (imat >= nmat)
+ break;
+
+#pragma unroll
+ for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
+ if(x < ne01 && y + j < ne00){
+ const int row = threadIdx.y+j;
+ const int col = threadIdx.x * sizeof(float)/sizeof(T);
+ T *tile2 = reinterpret_cast<T*>(tile[row]);
+ tile2[col] = src[imat*n + (y+j)*ne01 + x];
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) {
+ if (ty + j < ne01 && tx < ne00) {
+ const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T);
+ const T *tile2 = reinterpret_cast<const T*>(tile[threadIdx.x]);
+ dst[imat*n + (ty+j)*ne00 + tx] = tile2[col];
+ }
+ }
+ }
+
+ GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11,
+ nb12, nb13);
+}
+
+static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
+ float * cdstf = (float *)(cdsti);
+
+#pragma unroll
+ for (int j = 0; j < QK8_0; j += 2) {
+ float2 dq;
+ dequantize_q8_0(cxi, 0, j, dq);
+ *(cdstf + j) = dq.x;
+ *(cdstf + j + 1) = dq.y;
+ }
+}
+
+template<dequantize_kernel_t dequant, int qk>
+static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
+ float * cdstf = (float *)(cdsti);
+
+#pragma unroll
+ for (int j = 0; j < qk/2; j++) {
+ float2 dq;
+ dequant(cxi, 0, j, dq);
+ *(cdstf + j) = dq.x;
+ *(cdstf + j + qk/2) = dq.y;
+ }
+}
+
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+ const int64_t nb12, const int64_t nb13) {
+ const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+ if (i >= ne) {
+ return;
+ }
+
+ const int64_t i03 = i/(ne00 * ne01 * ne02);
+ const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+ const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+ const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+ const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+ const int64_t i13 = i/(ne10 * ne11 * ne12);
+ const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+ const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+ const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+ const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+ cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
+template <cpy_kernel_t cpy_blck, int qk>
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+ const int64_t nb12, const int64_t nb13) {
+ const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;
+
+ if (i >= ne) {
+ return;
+ }
+
+ const int64_t i03 = i/(ne00 * ne01 * ne02);
+ const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+ const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+ const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+ const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+ const int64_t i13 = i/(ne10 * ne11 * ne12);
+ const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+ const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+ const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+ const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+
+ cpy_blck(cx + x_offset, cdst + dst_offset);
+}
+
+template<typename src_t, typename dst_t>
+static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
+ const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= ne) {
+ return;
+ }
+
+ const src_t * x = (const src_t *) cx;
+ dst_t * dst = (dst_t *) cdst;
+
+ dst[i] = ggml_cuda_cast<dst_t>(x[i]);
+}
+
+template<typename src_t, typename dst_t>
+static void ggml_cpy_scalar_contiguous_cuda(
+ const char * cx, char * cdst, const int64_t ne,
+cudaStream_t stream) {
+
+ const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+ GGML_ASSERT(num_blocks < UINT_MAX);
+ cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+ (cx, cdst, ne);
+}
+
+template<typename src_t, typename dst_t, bool transposed = false>
+static void ggml_cpy_scalar_cuda(
+ const char * cx, char * cdst, const int64_t ne,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
+
+ if (transposed) {
+        GGML_ASSERT(ne == ne00*ne01*ne02); // assumes ne[3] == 1
+ int64_t ne00n, ne01n, ne02n;
+        if (nb00 <= nb02) { // it is most likely safe to also handle the nb00 == nb02 case here
+ ne00n = ne00;
+ ne01n = ne01;
+ ne02n = ne02;
+ } else {
+ ne00n = ne00;
+ ne01n = ne01*ne02;
+ ne02n = 1;
+ }
+
+ int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
+ int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
+ int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM;
+ GGML_ASSERT(grid_x < UINT_MAX);
+ GGML_ASSERT(grid_y < USHRT_MAX);
+ GGML_ASSERT(grid_z < USHRT_MAX);
+ dim3 dimGrid(grid_x, grid_y, grid_z);
+ dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
+ cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
+ (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ } else {
+ const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+ GGML_ASSERT(num_blocks < UINT_MAX);
+ cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+ }
+}
+
+static void ggml_cpy_f32_q8_0_cuda(
+ const char * cx, char * cdst, const int64_t ne,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK8_0 == 0);
+ const int64_t num_blocks = ne / QK8_0;
+ GGML_ASSERT(num_blocks < UINT_MAX);
+ cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q8_0_f32_cuda(
+ const char * cx, char * cdst, const int64_t ne,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
+
+ const int64_t num_blocks = ne;
+ GGML_ASSERT(num_blocks < UINT_MAX);
+ cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q4_0_cuda(
+ const char * cx, char * cdst, const int64_t ne,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK4_0 == 0);
+ const int64_t num_blocks = ne / QK4_0;
+ GGML_ASSERT(num_blocks < UINT_MAX);
+ cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q4_0_f32_cuda(
+ const char * cx, char * cdst, const int64_t ne,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02,
+ const int64_t nb00, const int64_t nb01, const int64_t nb02,
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+ const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
+ cudaStream_t stream) {
+ const int64_t num_blocks = ne;
+ GGML_ASSERT(num_blocks < UINT_MAX);
+ cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
+ cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q4_1_cuda(
+ const char * cx, char * cdst, const int64_t ne,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK4_1 == 0);
+ const int64_t num_blocks = ne / QK4_1;
+ GGML_ASSERT(num_blocks < UINT_MAX);
+ cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q4_1_f32_cuda(
+ const char * cx, char * cdst, const int64_t ne,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02,
+ const int64_t nb00, const int64_t nb01, const int64_t nb02,
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+ const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
+ cudaStream_t stream) {
+ const int64_t num_blocks = ne;
+ GGML_ASSERT(num_blocks < UINT_MAX);
+ cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
+ cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q5_0_cuda(
+ const char * cx, char * cdst, const int64_t ne,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK5_0 == 0);
+ const int64_t num_blocks = ne / QK5_0;
+ GGML_ASSERT(num_blocks < UINT_MAX);
+ cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q5_0_f32_cuda(
+ const char * cx, char * cdst, const int64_t ne,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02,
+ const int64_t nb00, const int64_t nb01, const int64_t nb02,
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+ const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
+ cudaStream_t stream) {
+ const int64_t num_blocks = ne;
+ GGML_ASSERT(num_blocks < UINT_MAX);
+ cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
+ cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_q5_1_cuda(
+ const char * cx, char * cdst, const int64_t ne,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK5_1 == 0);
+ const int64_t num_blocks = ne / QK5_1;
+ GGML_ASSERT(num_blocks < UINT_MAX);
+ cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_q5_1_f32_cuda(
+ const char * cx, char * cdst, const int64_t ne,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02,
+ const int64_t nb00, const int64_t nb01, const int64_t nb02,
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+ const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
+ cudaStream_t stream) {
+ const int64_t num_blocks = ne;
+ GGML_ASSERT(num_blocks < UINT_MAX);
+ cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
+ cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+ ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+static void ggml_cpy_f32_iq4_nl_cuda(
+ const char * cx, char * cdst, const int64_t ne,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+ const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
+
+ GGML_ASSERT(ne % QK4_NL == 0);
+ const int64_t num_blocks = ne / QK4_NL;
+ GGML_ASSERT(num_blocks < UINT_MAX);
+ cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
+ const int64_t ne = ggml_nelements(src0);
+ GGML_ASSERT(ne == ggml_nelements(src1));
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+
+ //GGML_ASSERT(src0->ne[3] == 1);
+
+ const int64_t nb00 = src0->nb[0];
+ const int64_t nb01 = src0->nb[1];
+ const int64_t nb02 = src0->nb[2];
+ const int64_t nb03 = src0->nb[3];
+
+ const int64_t ne10 = src1->ne[0];
+ const int64_t ne11 = src1->ne[1];
+ const int64_t ne12 = src1->ne[2];
+
+ //GGML_ASSERT(src1->ne[3] == 1);
+
+ const int64_t nb10 = src1->nb[0];
+ const int64_t nb11 = src1->nb[1];
+ const int64_t nb12 = src1->nb[2];
+ const int64_t nb13 = src1->nb[3];
+
+ cudaStream_t main_stream = ctx.stream();
+
+ char * src0_ddc = (char *) src0->data;
+ char * src1_ddc = (char *) src1->data;
+
+ const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
+ const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
+ src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);
+
+ if (src0->type == src1->type && contiguous_srcs) {
+ GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
+#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
+ if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
+ CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0));
+ } else
+#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
+ {
+ CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
+ }
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+ if (can_be_transposed) {
+ ggml_cpy_scalar_cuda<float, float, true>
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else {
+ ggml_cpy_scalar_cuda<float, float>
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ }
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
+ if (contiguous_srcs) {
+ ggml_cpy_scalar_contiguous_cuda<float, nv_bfloat16>
+ (src0_ddc, src1_ddc, ne, main_stream);
+ } else {
+ ggml_cpy_scalar_cuda<float, nv_bfloat16>
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ }
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+ if (contiguous_srcs) {
+ ggml_cpy_scalar_contiguous_cuda<float, half>
+ (src0_ddc, src1_ddc, ne, main_stream);
+ } else {
+ ggml_cpy_scalar_cuda<float, half>
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ }
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+ ggml_cpy_f32_q8_0_cuda
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
+ ggml_cpy_q8_0_f32_cuda
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+ ggml_cpy_f32_q4_0_cuda
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
+ ggml_cpy_q4_0_f32_cuda
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+ ggml_cpy_f32_q4_1_cuda
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
+ ggml_cpy_q4_1_f32_cuda
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+ ggml_cpy_f32_q5_0_cuda
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
+ ggml_cpy_q5_0_f32_cuda
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+ ggml_cpy_f32_iq4_nl_cuda
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+ ggml_cpy_f32_q5_1_cuda
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
+ ggml_cpy_q5_1_f32_cuda
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+ if (can_be_transposed) {
+ ggml_cpy_scalar_cuda<half, half, true>
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else {
+ ggml_cpy_scalar_cuda<half, half>
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ }
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
+ if (contiguous_srcs) {
+ ggml_cpy_scalar_contiguous_cuda<half, nv_bfloat16>
+ (src0_ddc, src1_ddc, ne, main_stream);
+ } else {
+ ggml_cpy_scalar_cuda<half, nv_bfloat16>
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ }
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+ if (contiguous_srcs) {
+ ggml_cpy_scalar_contiguous_cuda<half, float>
+ (src0_ddc, src1_ddc, ne, main_stream);
+ } else {
+ ggml_cpy_scalar_cuda<half, float>
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ }
+ } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
+ if (can_be_transposed) {
+ ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16, true>
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else {
+ ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16>
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ }
+ } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
+ if (contiguous_srcs) {
+ ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, half>
+ (src0_ddc, src1_ddc, ne, main_stream);
+ } else {
+ ggml_cpy_scalar_cuda<nv_bfloat16, half>
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ }
+ } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
+ if (contiguous_srcs) {
+ ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, float>
+ (src0_ddc, src1_ddc, ne, main_stream);
+ } else {
+ ggml_cpy_scalar_cuda<nv_bfloat16, float>
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ }
+ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
+ if (can_be_transposed) {
+ ggml_cpy_scalar_cuda<int32_t, int32_t, true>
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ } else {
+ ggml_cpy_scalar_cuda<int32_t, int32_t>
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ }
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
+ if (contiguous_srcs) {
+ ggml_cpy_scalar_contiguous_cuda<float, int32_t>
+ (src0_ddc, src1_ddc, ne, main_stream);
+ } else {
+ ggml_cpy_scalar_cuda<float, int32_t>
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ }
+ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
+ if (contiguous_srcs) {
+ ggml_cpy_scalar_contiguous_cuda<int32_t, float>
+ (src0_ddc, src1_ddc, ne, main_stream);
+ } else {
+ ggml_cpy_scalar_cuda<int32_t, float>
+ (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+ }
+ } else {
+ GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
+ ggml_type_name(src0->type), ggml_type_name(src1->type));
+ }
+}
+
+void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ ggml_cuda_cpy(ctx, src0, dst);
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/cpy.cuh b/llama.cpp/ggml/src/ggml-cuda/cpy.cuh
new file mode 100644
index 0000000..a7a87d8
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/cpy.cuh
@@ -0,0 +1,7 @@
+#include "common.cuh"
+
+#define CUDA_CPY_BLOCK_SIZE 64
+
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
+
+void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu b/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu
new file mode 100644
index 0000000..0c8b081
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cu
@@ -0,0 +1,177 @@
+#include "common.cuh"
+#include "cross-entropy-loss.cuh"
+#include "sum.cuh"
+
+#include <cmath>
+#include <cstdint>
+
+template <bool use_shared>
+static __global__ void cross_entropy_loss_f32(
+ const float * __restrict__ logits, const float * __restrict__ labels, float * __restrict__ dst, const int nclasses, const int k) {
+ extern __shared__ float tmp[];
+
+ logits += int64_t(blockIdx.x)*nclasses;
+ labels += int64_t(blockIdx.x)*nclasses;
+
+ // Find maximum for softmax:
+ float max_logit = -INFINITY;
+ for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+ const float val = logits[i];
+ max_logit = fmaxf(max_logit, val);
+
+ if (use_shared) {
+ tmp[i] = val;
+ }
+ }
+ max_logit = warp_reduce_max(max_logit);
+
+    // Calculate the softmax normalizer log(sum(exp(logits - max))):
+ float sum = 0.0f;
+ for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+ const float logit_i = use_shared ? tmp[i] : logits[i];
+ sum += expf(logit_i - max_logit);
+ }
+ sum = warp_reduce_sum(sum);
+ sum = logf(sum);
+
+ // log(exp(logits - max) / sum) = (logits - max) - log(sum)
+ float loss = 0.0f;
+ for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+ const float logit_i = use_shared ? tmp[i] : logits[i];
+ loss += (logit_i - max_logit - sum) * labels[i];
+ }
+ loss = -warp_reduce_sum(loss) / (float)k;
+
+ if (threadIdx.x != 0) {
+ return;
+ }
+
+ dst[blockIdx.x] = loss;
+}
+
+template <bool use_shared>
+static __global__ void cross_entropy_loss_back_f32(
+ const float * __restrict__ grad, const float * __restrict__ logits, const float * __restrict__ labels,
+ float * __restrict__ dst, const int nclasses) {
+ extern __shared__ float tmp[];
+
+ logits += int64_t(blockIdx.x)*nclasses;
+ labels += int64_t(blockIdx.x)*nclasses;
+ dst += int64_t(blockIdx.x)*nclasses;
+
+ float maxval = -INFINITY;
+ for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+ const float val = logits[i];
+ maxval = fmaxf(maxval, val);
+
+ if (use_shared) {
+ tmp[i] = val;
+ }
+ }
+ maxval = warp_reduce_max(maxval);
+
+ float sum = 0.0f;
+ for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+ const float val = expf((use_shared ? tmp[i] : logits[i]) - maxval);
+ sum += val;
+
+ if (use_shared) {
+ tmp[i] = val;
+ } else {
+ dst[i] = val;
+ }
+ }
+ sum = warp_reduce_sum(sum);
+ const float sm_scale = 1.0f/sum;
+
+ const float d_by_nrows = *grad/gridDim.x;
+ for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) {
+ const float val = use_shared ? tmp[i] : dst[i];
+ dst[i] = (val*sm_scale - labels[i])*d_by_nrows;
+ }
+}
+
+void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ const float * src0_d = (const float *) src0->data;
+ const float * src1_d = (const float *) src1->data;
+ float * dst_d = (float *) dst->data;
+
+ ggml_cuda_pool & pool = ctx.pool();
+ cudaStream_t stream = ctx.stream();
+
+ const dim3 blocks_dim(WARP_SIZE, 1, 1);
+ const dim3 blocks_num(nrows, 1, 1);
+ const size_t nbytes_shared = ne00*sizeof(float);
+
+ const int id = ggml_cuda_get_device();
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+ ggml_cuda_pool_alloc<float> dst_tmp(pool, blocks_num.x);
+
+ if (nbytes_shared <= smpbo) {
+ CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_f32<true>), smpbo);
+ cross_entropy_loss_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+ } else {
+ cross_entropy_loss_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(src0_d, src1_d, dst_tmp.ptr, ne00, nrows);
+ }
+ CUDA_CHECK(cudaGetLastError());
+
+ // Combine results from individual blocks:
+ sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream);
+}
+
+void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * grad = dst->src[0];
+ const ggml_tensor * src0f = dst->src[1];
+ const ggml_tensor * src1f = dst->src[2];
+
+ GGML_ASSERT(src0f->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1f->type == GGML_TYPE_F32);
+ GGML_ASSERT( grad->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_scalar(grad));
+ GGML_ASSERT(ggml_is_contiguous(src0f));
+ GGML_ASSERT(ggml_is_contiguous(src1f));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_are_same_shape(src0f, src1f));
+ GGML_ASSERT(ggml_are_same_shape(src0f, dst));
+
+ const int64_t ne00 = src0f->ne[0];
+ const int64_t nrows = ggml_nrows(src0f);
+
+ const float * grad_d = (const float *) grad->data;
+ const float * src0f_d = (const float *) src0f->data;
+ const float * src1f_d = (const float *) src1f->data;
+ float * dst_d = (float *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ const dim3 blocks_dim(WARP_SIZE, 1, 1);
+ const dim3 blocks_num(nrows, 1, 1);
+ const size_t nbytes_shared = ne00*sizeof(float);
+
+ const int id = ggml_cuda_get_device();
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+ if (nbytes_shared <= smpbo) {
+ CUDA_SET_SHARED_MEMORY_LIMIT((cross_entropy_loss_back_f32<true>), smpbo);
+ cross_entropy_loss_back_f32<true><<<blocks_num, blocks_dim, nbytes_shared, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
+ } else {
+ cross_entropy_loss_back_f32<false><<<blocks_num, blocks_dim, 0, stream>>>(grad_d, src0f_d, src1f_d, dst_d, ne00);
+ }
+}
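Note (not part of the patch): a single-row CPU reference for the math implemented by cross_entropy_loss_f32 above can make the kernel easier to follow. This sketch is hypothetical and assumes the usual math headers; the scaling by k matches the kernel, and the per-row results are later reduced by sum_f32_cuda.

// Hypothetical single-row reference, mirroring cross_entropy_loss_f32.
static float cross_entropy_row_reference(const float * logits, const float * labels,
                                         int nclasses, int k) {
    float max_logit = -INFINITY;
    for (int i = 0; i < nclasses; ++i) max_logit = fmaxf(max_logit, logits[i]);
    float sum = 0.0f;
    for (int i = 0; i < nclasses; ++i) sum += expf(logits[i] - max_logit);
    const float log_sum = logf(sum);
    float loss = 0.0f;
    for (int i = 0; i < nclasses; ++i) loss += (logits[i] - max_logit - log_sum) * labels[i];
    return -loss / (float) k;   // per-row contribution; rows are summed afterwards
}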
diff --git a/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cuh b/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cuh
new file mode 100644
index 0000000..9ec7152
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/cross-entropy-loss.cuh
@@ -0,0 +1,7 @@
+#include "common.cuh"
+
+#define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256
+
+void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/cumsum.cu b/llama.cpp/ggml/src/ggml-cuda/cumsum.cu
new file mode 100644
index 0000000..def9c32
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/cumsum.cu
@@ -0,0 +1,307 @@
+#include <algorithm>
+#include "cumsum.cuh"
+#include "convert.cuh"
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
+
+#ifdef GGML_CUDA_USE_CUB
+# include <cub/cub.cuh>
+#endif // GGML_CUDA_USE_CUB
+
+template<typename T, int BLOCK_SIZE>
+static __global__ void cumsum_cub_kernel(
+ const T * __restrict__ src,
+ T * __restrict__ dst,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const int64_t s01, const int64_t s02, const int64_t s03,
+ const int64_t s1, const int64_t s2, const int64_t s3) {
+#ifdef GGML_CUDA_USE_CUB
+ using BlockScanT = cub::BlockScan<T, BLOCK_SIZE>;
+
+ __shared__ typename BlockScanT::TempStorage temp_storage;
+ __shared__ T block_carry;
+
+ const int tid = threadIdx.x;
+ constexpr int UNROLL_FACTOR = 4;
+ constexpr int TILE_SIZE = BLOCK_SIZE * UNROLL_FACTOR;
+
+ const int64_t i1 = blockIdx.x;
+ const int64_t i2 = blockIdx.y;
+ const int64_t i3 = blockIdx.z;
+
+ if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
+ return;
+ }
+
+ const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
+ T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
+
+ if (tid == 0) {
+ block_carry = 0;
+ }
+ __syncthreads();
+
+ for (int64_t start = 0; start < ne00; start += TILE_SIZE) {
+ T items[UNROLL_FACTOR];
+ T thread_sum = T(0);
+
+#pragma unroll
+ for (int i = 0; i < UNROLL_FACTOR; i++) {
+ int64_t idx = start + tid * UNROLL_FACTOR + i;
+ T val = (idx < ne00) ? src_row[idx] : T(0);
+ thread_sum += val;
+ items[i] = thread_sum;
+ }
+
+ // Block-wide scan on thread sums
+ T thread_prefix;
+ T block_total;
+ BlockScanT(temp_storage).InclusiveSum(thread_sum, thread_prefix, block_total);
+ __syncthreads();
+
+ // Add offset to each item and store
+ T thread_offset = thread_prefix - thread_sum + block_carry;
+#pragma unroll
+ for (int i = 0; i < UNROLL_FACTOR; i++) {
+ int64_t idx = start + tid * UNROLL_FACTOR + i;
+ if (idx < ne00) {
+ dst_row[idx] = items[i] + thread_offset;
+ }
+ }
+
+ __syncthreads();
+
+ // Update carry for next tile
+ if (tid == 0) {
+ block_carry += block_total;
+ }
+ }
+#else
+ NO_DEVICE_CODE;
+#endif // GGML_CUDA_USE_CUB
+}
+
+// Fallback kernel implementation
+template<typename T>
+static __global__ void cumsum_kernel(
+ const T * src, T * dst,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const int64_t s00, const int64_t s01, const int64_t s02, const int64_t s03,
+ const int64_t s0, const int64_t s1, const int64_t s2, const int64_t s3) {
+
+ GGML_UNUSED_VARS(s00, s0);
+
+ const int tid = threadIdx.x;
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+ const int lane = tid % warp_size;
+ const int warp = tid / warp_size;
+ const int warps_per_block = blockDim.x / warp_size;
+
+ extern __shared__ float smem[];
+ float * s_vals = smem;
+ float * s_warp_sums = smem + blockDim.x;
+ float * s_carry = smem + blockDim.x + warps_per_block;
+ float * s_chunk_total = s_carry + 1;
+
+ // Initialize carry
+ if (tid == 0) {
+ *s_carry = 0.0f;
+ }
+ __syncthreads();
+
+ const int64_t i3 = blockIdx.z;
+ const int64_t i2 = blockIdx.y;
+ const int64_t i1 = blockIdx.x;
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+ return;
+ }
+
+ const T * src_row = src + i1 * s01 + i2 * s02 + i3 * s03;
+ T * dst_row = dst + i1 * s1 + i2 * s2 + i3 * s3;
+
+ // register blocking: process 4 elements per thread to hide latency
+ // and reduce synchronization overhead
+ constexpr int num_unroll = 4;
+ T temp[num_unroll];
+
+ for (int64_t i = 0; i < ne00; i += num_unroll * blockDim.x) {
+ int64_t idx = i + tid * num_unroll;
+
+ // thread local sequential scan
+ temp[0] = (idx < ne00 ? src_row[idx] : T(0));
+#pragma unroll
+ for (int64_t j = 1; j < num_unroll; j++) {
+ temp[j] = temp[j - 1];
+ if (idx + j < ne00) {
+ temp[j] += src_row[idx + j];
+ } else {
+ temp[j] += 0;
+ }
+ }
+
+        // the last element is the sum of all values assigned to this thread
+ float val = (idx < ne00) ? ggml_cuda_cast<float, T>(temp[num_unroll - 1]) : 0.0f;
+
+ // Warp inclusive scan
+ val = warp_prefix_inclusive_sum<T, warp_size>(val);
+ s_vals[tid] = val;
+
+ if (lane == warp_size - 1) {
+ s_warp_sums[warp] = val;
+ }
+ __syncthreads();
+
+ // Exclusive scan of warp sums (warp 0 only)
+ if (warp == 0) {
+ float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
+ float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
+ if (tid < warps_per_block) {
+ s_warp_sums[tid] = inc - w; // exclusive sum
+ }
+ if (tid == warps_per_block - 1) {
+ *s_chunk_total = inc; // total sum of this chunk
+ }
+ }
+ __syncthreads();
+
+ // write back results
+ float carry = *s_carry;
+ // calculate sum offset for this thread
+ float final_val_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1];
+
+#pragma unroll
+ for (int32_t j = 0; j < num_unroll; j++) {
+ if (idx + j < ne00) {
+ dst_row[idx + j] = temp[j] + ggml_cuda_cast<T, float>(final_val_offset);
+ }
+ }
+
+ __syncthreads();
+
+ // Update carry for next chunk
+ if (tid == 0) {
+ *s_carry += *s_chunk_total;
+ }
+ }
+}
+
+#ifdef GGML_CUDA_USE_CUB
+template <typename T>
+static void cumsum_cub(ggml_cuda_pool & pool,
+ const T * src,
+ T * dst,
+ int64_t ne,
+ cudaStream_t stream) {
+ size_t tmp_size = 0;
+
+ // Query how much temp storage CUDA UnBound (CUB) needs
+ cub::DeviceScan::InclusiveSum(nullptr, // d_temp_storage (null = just query size)
+ tmp_size, // reference to size (will be set by CUB)
+ src, // input pointer
+ dst, // output pointer
+ ne, // number of elements
+ stream // CUDA stream to use
+ );
+
+ ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+
+ // Perform the inclusive scan
+ cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream);
+}
+#endif // GGML_CUDA_USE_CUB
+
+template<typename T>
+static void cumsum_cuda(
+ [[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+ const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+ cudaStream_t stream) {
+
+ const size_t type_size = sizeof(T);
+ bool use_cub = false;
+#ifdef GGML_CUDA_USE_CUB
+ // Check if we can use CUB (data must be contiguous along innermost dimension)
+ const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size);
+
+ if (is_contiguous) {
+ use_cub = true;
+ const int64_t nrows = ne01 * ne02 * ne03;
+ // TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
+ // Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
+ if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {
+ for (int i=0; i<nrows; i++) {
+ cumsum_cub(ctx.pool(), src + i * ne00, dst + i * ne00, ne00, stream);
+ }
+ return;
+ }
+ }
+#endif // GGML_CUDA_USE_CUB
+ dim3 grid_dims(ne01, ne02, ne03);
+ const auto &info = ggml_cuda_info().devices[ggml_cuda_get_device()];
+ const int warp_size = info.warp_size;
+ const int num_warps = (ne00 + warp_size - 1) / warp_size;
+ int block_size = num_warps * warp_size;
+ block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
+ dim3 block_dims(block_size, 1, 1);
+ const int warps_per_block = block_size / warp_size;
+ const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
+
+ if (use_cub && ne00 >= 1024) {
+ cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
+ src, dst,
+ ne00, ne01, ne02, ne03,
+ nb01 / type_size, nb02 / type_size, nb03 / type_size,
+ nb1 / type_size, nb2 / type_size, nb3 / type_size
+ );
+ } else {
+ cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
+ src, dst,
+ ne00, ne01, ne02, ne03,
+ nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+ nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+ );
+ }
+}
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == dst->type);
+ switch(src0->type) {
+ case GGML_TYPE_F32:
+ {
+ cumsum_cuda(
+ ctx, (const float *)src0->data, (float *)dst->data,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+ stream
+ );
+ } break;
+ // We do not support those on CPU for now anyway, so comment them out because they cause errors on some CI platforms
+ /*case GGML_TYPE_F16:
+ {
+ cumsum_cuda(
+ (const half *)src0->data, (half *)dst->data,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+ stream
+ );
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ cumsum_cuda(
+ (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+ stream
+ );
+ } break;*/
+ default:
+ GGML_ABORT("fatal error");
+ }
+}
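Note (not part of the patch): both the CUB path and the fallback kernel compute an inclusive prefix sum along the innermost (ne00) dimension of each row; the tiling, warp scans, and carries only parallelize that scan. A hypothetical CPU reference clarifies the expected output:

// Illustrative single-row reference for the inclusive scan computed by cumsum_cuda.
static void cumsum_row_reference(const float * src, float * dst, int64_t ne00) {
    float running = 0.0f;
    for (int64_t i = 0; i < ne00; ++i) {
        running += src[i];
        dst[i] = running;   // inclusive: dst[i] = src[0] + ... + src[i]
    }
}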
diff --git a/llama.cpp/ggml/src/ggml-cuda/cumsum.cuh b/llama.cpp/ggml/src/ggml-cuda/cumsum.cuh
new file mode 100644
index 0000000..782d1d9
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/cumsum.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_CUMSUM_BLOCK_SIZE 256
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh b/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh
new file mode 100644
index 0000000..e060fb2
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/dequantize.cuh
@@ -0,0 +1,77 @@
+#include "common.cuh"
+
+static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
+ const block_q4_0 * x = (const block_q4_0 *) vx;
+
+ const float d = x[ib].d;
+
+ const int vui = x[ib].qs[iqs];
+
+ v.x = vui & 0xF;
+ v.y = vui >> 4;
+
+ v.x = (v.x - 8.0f) * d;
+ v.y = (v.y - 8.0f) * d;
+}
+
+static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int64_t ib, const int iqs, float2 & v){
+ const block_q4_1 * x = (const block_q4_1 *) vx;
+
+ const float2 dm = __half22float2(x[ib].dm);
+
+ const int vui = x[ib].qs[iqs];
+
+ v.x = vui & 0xF;
+ v.y = vui >> 4;
+
+ v.x = (v.x * dm.x) + dm.y;
+ v.y = (v.y * dm.x) + dm.y;
+}
+
+static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
+ const block_q5_0 * x = (const block_q5_0 *) vx;
+
+ const float d = x[ib].d;
+
+ uint32_t qh;
+ memcpy(&qh, x[ib].qh, sizeof(qh));
+
+ const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
+ const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
+
+ v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+ v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
+
+ v.x = (v.x - 16.0f) * d;
+ v.y = (v.y - 16.0f) * d;
+}
+
+static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int64_t ib, const int iqs, float2 & v){
+ const block_q5_1 * x = (const block_q5_1 *) vx;
+
+ const float2 dm = __half22float2(x[ib].dm);
+
+ uint32_t qh;
+ memcpy(&qh, x[ib].qh, sizeof(qh));
+
+ const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
+ const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
+
+ v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+ v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
+
+ v.x = (v.x * dm.x) + dm.y;
+ v.y = (v.y * dm.x) + dm.y;
+}
+
+static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, float2 & v){
+ const block_q8_0 * x = (const block_q8_0 *) vx;
+
+ const float d = x[ib].d;
+
+ v.x = x[ib].qs[iqs + 0];
+ v.y = x[ib].qs[iqs + 1];
+
+ v.x *= d;
+ v.y *= d;
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/diag.cu b/llama.cpp/ggml/src/ggml-cuda/diag.cu
new file mode 100644
index 0000000..5cea210
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/diag.cu
@@ -0,0 +1,77 @@
+#include "convert.cuh"
+#include "diag.cuh"
+#include "ggml.h"
+
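+// Writes the single source row of each batch onto the main diagonal of the corresponding ne0 x ne1 destination matrix
+// and sets all off-diagonal elements to zero.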
+template <typename T>
+static __global__ void diag_kernel(T * __restrict__ dst,
+ const T * __restrict__ src,
+ const int64_t ne0,
+ const int64_t ne1,
+ const int64_t ne2,
+ const int64_t ne3,
+ const int64_t total_elements) {
+ const int64_t global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (global_idx >= total_elements) {
+ return;
+ }
+
+ const int64_t i0 = global_idx % ne0;
+ const int64_t i1 = (global_idx / ne0) % ne1;
+ const int64_t i2 = (global_idx / (ne0 * ne1)) % ne2;
+ const int64_t i3 = global_idx / (ne0 * ne1 * ne2);
+
+ const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0;
+
+ if (i0 == i1) {
+ const int64_t batch_idx = i3 * ne2 + i2;
+ const int64_t src_idx = batch_idx * ne0 + i0;
+ dst[dst_idx] = src[src_idx];
+ } else {
+ dst[dst_idx] = ggml_cuda_cast<T>(0);
+ }
+ GGML_UNUSED_VARS(ne3);
+}
+
+void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ void * dst_d = dst->data;
+ const void * src0_d = src0->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+ const int64_t ne03 = src0->ne[3];
+
+ const int64_t ne0 = dst->ne[0];
+ const int64_t ne1 = dst->ne[1];
+ const int64_t ne2 = dst->ne[2];
+ const int64_t ne3 = dst->ne[3];
+
+ GGML_ASSERT(ne00 == ne0);
+ GGML_ASSERT(ne01 == 1);
+ GGML_ASSERT(ne02 == ne2);
+ GGML_ASSERT(ne03 == ne3);
+
+ const int64_t n_elems = ggml_nelements(dst);
+ const int64_t num_blocks = (n_elems + CUDA_DIAG_BLOCK_SIZE - 1) / CUDA_DIAG_BLOCK_SIZE;
+
+ switch (dst->type) {
+ case GGML_TYPE_F32:
+ diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>((float *) dst_d, (const float *) src0_d, ne0,
+ ne1, ne2, ne3, n_elems);
+ break;
+ case GGML_TYPE_F16:
+ diag_kernel<<<num_blocks, CUDA_DIAG_BLOCK_SIZE, 0, stream>>>((half *) dst_d, (const half *) src0_d, ne0,
+ ne1, ne2, ne3, n_elems);
+ break;
+ default:
+ GGML_ABORT("unsupported type");
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/diag.cuh b/llama.cpp/ggml/src/ggml-cuda/diag.cuh
new file mode 100644
index 0000000..7d73e6a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/diag.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_DIAG_BLOCK_SIZE 256
+
+void ggml_cuda_op_diag(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/diagmask.cu b/llama.cpp/ggml/src/ggml-cuda/diagmask.cu
new file mode 100644
index 0000000..4b713ba
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/diagmask.cu
@@ -0,0 +1,40 @@
+#include "diagmask.cuh"
+
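+// Causal masking: entries with col > n_past + row % rows_per_channel effectively become -inf (x[i] - FLT_MAX),
+// all other entries are copied unchanged.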
+static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
+ const int col = blockDim.y*blockIdx.y + threadIdx.y;
+ const int row = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (col >= ncols) {
+ return;
+ }
+
+ const int i = row*ncols + col;
+ //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
+ //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+ dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
+}
+
+static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
+ const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
+ const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
+ const dim3 block_nums(nrows_x, block_num_x, 1);
+ diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
+}
+
+void ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int nrows0 = ggml_nrows(src0);
+
+ const int n_past = ((int32_t *) dst->op_params)[0];
+
+ diag_mask_inf_f32_cuda(src0_d, dst_d, ne00, nrows0, ne01, n_past, stream);
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/diagmask.cuh b/llama.cpp/ggml/src/ggml-cuda/diagmask.cuh
new file mode 100644
index 0000000..6cdbef1
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/diagmask.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
+
+void ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh b/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh
new file mode 100644
index 0000000..b6a7460
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/fattn-common.cuh
@@ -0,0 +1,1036 @@
+#pragma once
+
+#include "common.cuh"
+#include "convert.cuh"
+#include "vecdotq.cuh"
+
+#include <cstdint>
+
+#define FATTN_KQ_STRIDE 256
+#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
+#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
+
+// log(2) = 0.6931; adding this to the KQ maximum used for the softmax effectively shifts the numerical range
+// representable by the VKQ accumulators up by a factor of 2.
+// This reduces issues with numerical overflow but also causes larger values to be flushed to zero.
+// However, as the output from FlashAttention will usually be used as an input for a matrix multiplication, this should be negligible.
+// Still, the value range should be shifted as much as necessary but as little as possible.
+// The macro on the following line shifts it by a factor of 2**3=8, as was needed to fix https://github.com/ggml-org/llama.cpp/issues/18606 .
+#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f)
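+// Example: with an offset of 3*log(2) each softmax numerator becomes expf(KQ - (KQ_max + 3*log(2))) = expf(KQ - KQ_max)/8,
+// so the values accumulated into VKQ are 8x smaller and the effective range of the accumulators is shifted up by a factor of 8.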
+
+typedef void (* fattn_kernel_t)(
+ const char * __restrict__ Q,
+ const char * __restrict__ K,
+ const char * __restrict__ V,
+ const char * __restrict__ mask,
+ const char * __restrict__ sinks,
+ const int * __restrict__ KV_max,
+ float * __restrict__ dst,
+ float2 * __restrict__ dst_meta,
+ const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
+ const float logit_softcap,
+ const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
+ const int32_t nb31, const int32_t nb32, const int64_t nb33);
+
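+// Computes the dot product of one K row with one Q column for the vector kernels.
+// Q is passed pre-processed: as f16/f32 values in Q_v when K is f16, or quantized to q8_1 in Q_q8 with the per-block scales/sums in Q_ds.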
+typedef float (*vec_dot_KQ_t)(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
+
+template <int D, int nthreads>
+static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
+
+ const half2 * K_h2 = (const half2 *) K_c;
+ GGML_UNUSED(Q_q8);
+ GGML_UNUSED(Q_ds_v);
+
+ constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
+ constexpr int cpy_ne = cpy_nb / 4;
+
+ float sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
+ __align__(16) half2 tmp[cpy_ne];
+ ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
+#pragma unroll
+ for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
+#ifdef V_DOT2_F32_F16_AVAILABLE
+ ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
+#else
+ ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
+#endif // V_DOT2_F32_F16_AVAILABLE
+ }
+ }
+
+ return sum;
+}
+
+template<int D, int nthreads>
+static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+ const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
+ GGML_UNUSED(Q_v);
+
+ float sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+ const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
+
+ const int ib = k_KQ / QI8_1;
+ const int iqs4 = k_KQ % QI4_0;
+ const int shift = k_KQ & (QI8_1/2);
+
+ int v;
+ ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q4_0[ib].qs + sizeof(int)*iqs4);
+ v = (v >> shift) & 0x0F0F0F0F;
+ const int u = Q_q8[k_KQ_0/nthreads];
+
+ const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+ const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
+ sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x - (8/QI8_1)*Q_ds.y);
+ }
+
+ return sum;
+}
+
+template<int D, int nthreads>
+static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_1(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+ const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
+ GGML_UNUSED(Q_v);
+
+ float sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+ const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
+
+ const int ib = k_KQ / QI8_1;
+ const int iqs4 = k_KQ % QI4_1;
+ const int shift = k_KQ & (QI8_1/2);
+
+ int v;
+ ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q4_1[ib].qs + sizeof(int)*iqs4);
+ v = (v >> shift) & 0x0F0F0F0F;
+ const int u = Q_q8[k_KQ_0/nthreads];
+
+ const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+ const float2 K_dm = __half22float2(K_q4_1[ib].dm);
+ const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
+
+ sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;
+ }
+
+ return sum;
+}
+
+template<int D, int nthreads>
+static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_0(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+ const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
+ GGML_UNUSED(Q_v);
+
+ float sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+ const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
+
+ const int ib = k_KQ / QI8_1;
+ const int iqs4 = k_KQ % QI5_0;
+ const int iqs8 = k_KQ % QI8_1;
+ const int shift = k_KQ & (QI8_1/2);
+
+ int v;
+ ggml_cuda_memcpy_1<sizeof(int), 2>(&v, K_q5_0[ib].qs + sizeof(int)*iqs4);
+ v = (v >> shift) & 0x0F0F0F0F;
+
+ {
+ int vh;
+ ggml_cuda_memcpy_1<sizeof(int), 2>(&vh, K_q5_0[ib].qh);
+ vh >>= iqs8 * QI5_0;
+
+ v |= (vh << 4) & 0x00000010; // 0 -> 4
+ v |= (vh << 11) & 0x00001000; // 1 -> 12
+ v |= (vh << 18) & 0x00100000; // 2 -> 20
+ v |= (vh << 25) & 0x10000000; // 3 -> 28
+ }
+
+ const int u = Q_q8[k_KQ_0/nthreads];
+
+ const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+ const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
+
+ sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x - (16/QI8_1)*Q_ds.y);
+ }
+
+ return sum;
+}
+
+template<int D, int nthreads>
+static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q5_1(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+ const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
+ GGML_UNUSED(Q_v);
+
+ float sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+ const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
+
+ const int ib = k_KQ / QI8_1;
+ const int iqs4 = k_KQ % QI5_1;
+ const int iqs8 = k_KQ % QI8_1;
+ const int shift = k_KQ & (QI8_1/2);
+
+ int v;
+ ggml_cuda_memcpy_1<sizeof(int)>(&v, K_q5_1[ib].qs + sizeof(int)*iqs4);
+ v = (v >> shift) & 0x0F0F0F0F;
+
+ {
+ int vh;
+ ggml_cuda_memcpy_1<sizeof(int)>(&vh, K_q5_1[ib].qh);
+ vh >>= iqs8 * QI5_0;
+
+ v |= (vh << 4) & 0x00000010; // 0 -> 4
+ v |= (vh << 11) & 0x00001000; // 1 -> 12
+ v |= (vh << 18) & 0x00100000; // 2 -> 20
+ v |= (vh << 25) & 0x10000000; // 3 -> 28
+ }
+
+ const int u = Q_q8[k_KQ_0/nthreads];
+
+ const int sumi = ggml_cuda_dp4a(v, u, 0);
+
+ const float2 K_dm = __half22float2(K_q5_1[ib].dm);
+ const float2 Q_ds = ((const float2 *) Q_ds_v)[k_KQ_0/nthreads];
+
+ sum += K_dm.x*Q_ds.x*sumi + K_dm.y*Q_ds.y/QI8_1;
+ }
+
+ return sum;
+}
+
+template <int D, int nthreads>
+static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q8_0(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
+
+ const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
+ GGML_UNUSED(Q_v);
+
+ float sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+ const int k_KQ = k_KQ_0 + (nthreads == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads);
+
+ const int ib = k_KQ / QI8_0;
+ const int iqs = k_KQ % QI8_0;
+
+ int v;
+ ggml_cuda_memcpy_1<sizeof(v), 2>(&v, K_q8_0[ib].qs + 4*iqs);
+
+ const float2 * Q_ds = (const float2 *) Q_ds_v;
+ const float Q_d = Q_ds[k_KQ_0/nthreads].x;
+
+ sum += vec_dot_q8_0_q8_1_impl<float, 1>(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d);
+ }
+
+ return sum;
+}
+
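+// Quantizes Q values to q8_1 and stores them in shared memory: each thread handles 4 consecutive floats,
+// the absmax and sum of each block of QK8_1 values are reduced across the threads sharing that block,
+// and the resulting scale d and sum are written to yds (as half2 or float2, depending on Tds).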
+template <typename Tds, int ni>
+static __device__ __forceinline__ void quantize_q8_1_to_shared(
+ const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) {
+
+ float vals[sizeof(int)] = {0.0f};
+#pragma unroll
+ for (int l = 0; l < int(sizeof(int)); ++l) {
+ vals[l] = (ni == WARP_SIZE || threadIdx.x < ni) ? scale * x[4*threadIdx.x + l] : 0.0f;
+ }
+
+ float amax = fabsf(vals[0]);
+ float sum = vals[0];
+#pragma unroll
+ for (int l = 1; l < int(sizeof(int)); ++l) {
+ amax = fmaxf(amax, fabsf(vals[l]));
+ sum += vals[l];
+ }
+#pragma unroll
+ for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
+ amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, 32));
+ sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, 32);
+ }
+
+ const float d = amax / 127;
+ int q32 = 0;
+ int8_t * q8 = (int8_t *) &q32;
+
+ if (d != 0.0f) {
+#pragma unroll
+ for (int l = 0; l < int(sizeof(int)); ++l) {
+ q8[l] = roundf(vals[l] / d);
+ }
+ }
+
+ yq32[threadIdx.x] = q32;
+ if (threadIdx.x % QI8_1 == 0 && (ni == WARP_SIZE || threadIdx.x < ni)) {
+ if (std::is_same<Tds, half2>::value) {
+ ((half2 *) yds)[threadIdx.x/QI8_1] = make_half2(d, sum);
+ } else {
+ ((float2 *) yds)[threadIdx.x/QI8_1] = make_float2(d, sum);
+ }
+ }
+}
+
+typedef void (*dequantize_V_t)(const void *, void *, const int64_t);
+
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+ if constexpr (std::is_same_v<T, half>) {
+ ggml_cuda_memcpy_1<ne*sizeof(half)>(dst, (const half *) vx + i0);
+ } else if constexpr (std::is_same_v<T, float>) {
+ static_assert(ne % 2 == 0, "bad ne");
+ __align__(16) half2 tmp[ne/2];
+ ggml_cuda_memcpy_1<ne*sizeof(half)>(tmp, (const half *) vx + i0);
+ float2 * dst_f2 = (float2 *) dst;
+#pragma unroll
+ for (int l = 0; l < ne/2; ++l) {
+ dst_f2[l] = __half22float2(tmp[l]);
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "unsupported type");
+ }
+}
+
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+ const block_q4_0 * x = (const block_q4_0 *) vx;
+
+ const int64_t ib = i0 / QK4_0;
+ const int iqs = i0 % (QK4_0/2);
+ const int shift = (i0 % QK4_0) / (QK4_0/2);
+
+ int q;
+ static_assert(ne == 2 || ne == 4, "bad ne");
+ ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
+ q >>= 4*shift;
+ q &= 0x0F0F0F0F;
+ q = __vsubss4(q, 0x08080808);
+
+ const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef FP16_AVAILABLE
+ if constexpr (std::is_same_v<T, half>) {
+ const half2 d = __half2half2(x[ib].d);
+
+#pragma unroll
+ for (int l0 = 0; l0 < ne; l0 += 2) {
+ ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
+ }
+ } else
+#endif // FP16_AVAILABLE
+ if constexpr (std::is_same_v<T, float>) {
+ const float d = x[ib].d;
+
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ ((float *) dst)[l] = d * q8[l];
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "bad type");
+ }
+}
+
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+ const block_q4_1 * x = (const block_q4_1 *) vx;
+
+ const int64_t ib = i0 / QK4_1;
+ const int iqs = i0 % (QK4_1/2);
+ const int shift = (i0 % QK4_1) / (QK4_1/2);
+
+ int q;
+ static_assert(ne == 2 || ne == 4, "bad ne");
+ ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
+ q >>= 4*shift;
+ q &= 0x0F0F0F0F;
+
+ const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef FP16_AVAILABLE
+ if constexpr (std::is_same_v<T, half>) {
+ const half2 dm = x[ib].dm;
+ const half2 d = __half2half2( __low2half(dm));
+ const half2 m = __half2half2(__high2half(dm));
+
+#pragma unroll
+ for (int l0 = 0; l0 < ne; l0 += 2) {
+ ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
+ }
+ } else
+#endif // FP16_AVAILABLE
+ if constexpr (std::is_same_v<T, float>) {
+ const float2 dm = __half22float2(x[ib].dm);
+
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ ((float *) dst)[l] = dm.x * q8[l] + dm.y;
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "bad type");
+ }
+}
+
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+ const block_q5_0 * x = (const block_q5_0 *) vx;
+
+ const int64_t ib = i0 / QK5_0;
+ const int idq = i0 % QK5_0;
+ const int iqs = i0 % (QK5_0/2);
+ const int shift = (i0 % QK5_0) / (QK5_0/2);
+
+ int q;
+ static_assert(ne == 2 || ne == 4, "bad ne");
+ ggml_cuda_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
+ q >>= 4*shift;
+ q &= 0x0F0F0F0F;
+
+ {
+ int qh;
+ ggml_cuda_memcpy_1<ne, 2>(&qh, x[ib].qh);
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
+ }
+ }
+
+ q = __vsubss4(q, 0x10101010);
+
+ const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef FP16_AVAILABLE
+ if constexpr (std::is_same_v<T, half>) {
+ const half2 d = __half2half2(x[ib].d);
+
+#pragma unroll
+ for (int l0 = 0; l0 < ne; l0 += 2) {
+ ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]);
+ }
+ } else
+#endif // FP16_AVAILABLE
+ if constexpr (std::is_same_v<T, float>) {
+ const float d = x[ib].d;
+
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ ((float *) dst)[l] = d * q8[l];
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "bad type");
+ }
+}
+
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+ const block_q5_1 * x = (const block_q5_1 *) vx;
+
+ const int64_t ib = i0 / QK5_1;
+ const int idq = i0 % QK5_1;
+ const int iqs = i0 % (QK5_1/2);
+ const int shift = (i0 % QK5_1) / (QK5_1/2);
+
+ int q;
+ static_assert(ne == 2 || ne == 4, "bad ne");
+ ggml_cuda_memcpy_1<ne>(&q, x[ib].qs + iqs);
+ q >>= 4*shift;
+ q &= 0x0F0F0F0F;
+
+ {
+ int qh;
+ ggml_cuda_memcpy_1<ne>(&qh, x[ib].qh);
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
+ }
+ }
+
+ const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef FP16_AVAILABLE
+ if constexpr (std::is_same_v<T, half>) {
+ const half2 dm = x[ib].dm;
+ const half2 d = __half2half2( __low2half(dm));
+ const half2 m = __half2half2(__high2half(dm));
+
+#pragma unroll
+ for (int l0 = 0; l0 < ne; l0 += 2) {
+ ((half2 *) dst)[l0/2] = d * make_half2(q8[l0 + 0], q8[l0 + 1]) + m;
+ }
+ } else
+#endif // FP16_AVAILABLE
+ if constexpr (std::is_same_v<T, float>) {
+ const float2 dm = __half22float2(x[ib].dm);
+
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ ((float *) dst)[l] = dm.x * q8[l] + dm.y;
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "bad type");
+ }
+}
+
+template <typename T, int ne>
+static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+ const block_q8_0 * x = (const block_q8_0 *) vx;
+
+ const int64_t ib = i0 / QK8_0;
+ const int iqs = i0 % QK8_0;
+
+ static_assert(ne % 2 == 0, "bad ne");
+ int8_t qs[ne];
+ ggml_cuda_memcpy_1<ne, 2>(qs, x[ib].qs + iqs);
+
+#ifdef FP16_AVAILABLE
+ if constexpr (std::is_same<T, half>::value) {
+ const half2 d = __half2half2(x[ib].d);
+
+#pragma unroll
+ for (int l0 = 0; l0 < ne; l0 += 2) {
+ ((half2 *) dst)[l0/2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]);
+ }
+ } else
+#endif // FP16_AVAILABLE
+ if constexpr (std::is_same<T, float>::value) {
+ const float d = x[ib].d;
+
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ ((float *) dst)[l] = d * qs[l];
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "unsupported type");
+ }
+}
+
+template <ggml_type type_K, int D, int nthreads>
+constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() {
+ if constexpr (type_K == GGML_TYPE_F16) {
+ return vec_dot_fattn_vec_KQ_f16<D, nthreads>;
+ } else if constexpr (type_K == GGML_TYPE_Q4_0) {
+ return vec_dot_fattn_vec_KQ_q4_0<D, nthreads>;
+ } else if constexpr (type_K == GGML_TYPE_Q4_1) {
+ return vec_dot_fattn_vec_KQ_q4_1<D, nthreads>;
+ } else if constexpr (type_K == GGML_TYPE_Q5_0) {
+ return vec_dot_fattn_vec_KQ_q5_0<D, nthreads>;
+ } else if constexpr (type_K == GGML_TYPE_Q5_1) {
+ return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>;
+ } else if constexpr (type_K == GGML_TYPE_Q8_0) {
+ return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>;
+ } else {
+ static_assert(type_K == -1, "bad type");
+ return nullptr;
+ }
+}
+
+template <ggml_type type_V, typename T, int ne>
+constexpr __device__ dequantize_V_t get_dequantize_V() {
+ if constexpr (type_V == GGML_TYPE_F16) {
+ return dequantize_V_f16<T, ne>;
+ } else if constexpr (type_V == GGML_TYPE_Q4_0) {
+ return dequantize_V_q4_0<T, ne>;
+ } else if constexpr (type_V == GGML_TYPE_Q4_1) {
+ return dequantize_V_q4_1<T, ne>;
+ } else if constexpr (type_V == GGML_TYPE_Q5_0) {
+ return dequantize_V_q5_0<T, ne>;
+ } else if constexpr (type_V == GGML_TYPE_Q5_1) {
+ return dequantize_V_q5_1<T, ne>;
+ } else if constexpr (type_V == GGML_TYPE_Q8_0) {
+ return dequantize_V_q8_0<T, ne>;
+ } else {
+ static_assert(type_V == -1, "bad type");
+ return nullptr;
+ }
+}
+
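+// Scans the mask backwards in steps of FATTN_KQ_STRIDE and determines, for each (sequence, Q tile) pair, the KV length
+// beyond which all mask values are -inf. The result is written to KV_max so that later kernels can skip the fully masked tail.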
+template <int ncols1>
+__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
+static __global__ void flash_attn_mask_to_KV_max(
+ const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
+ const int ne31 = gridDim.x;
+ const int tid = threadIdx.x;
+ const int sequence = blockIdx.y;
+ const int jt = blockIdx.x;
+
+ mask += sequence*s33 + jt*ncols1*s31;
+
+ __shared__ int buf_iw[WARP_SIZE];
+ if (tid < WARP_SIZE) {
+ buf_iw[tid] = 1;
+ }
+ __syncthreads();
+
+ int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
+ for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
+ int all_inf = 1;
+
+#pragma unroll
+ for (int j = 0; j < ncols1; ++j) {
+ const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
+ all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
+ }
+
+ all_inf = warp_reduce_all(all_inf);
+ if (tid % WARP_SIZE == 0) {
+ buf_iw[tid / WARP_SIZE] = all_inf;
+ }
+ __syncthreads();
+ all_inf = buf_iw[tid % WARP_SIZE];
+ __syncthreads();
+ all_inf = warp_reduce_all(all_inf);
+
+ if (!all_inf) {
+ break;
+ }
+ }
+
+ // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
+ // If the break was triggered it's the lower edge of the tile with the first non-masked values.
+    // In either case, undo the last decrement of FATTN_KQ_STRIDE.
+ KV_max_sj += FATTN_KQ_STRIDE;
+
+ if (threadIdx.x != 0) {
+ return;
+ }
+
+ KV_max[sequence*ne31 + jt] = KV_max_sj;
+}
+
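+// Combines the partial results of CUDA blocks that worked on fractions of the same KV tile under the stream-k decomposition:
+// the partial VKQ values are rescaled according to their softmax maxima, summed, and normalized by the combined row sum.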
+template<int D, int ncols1, int ncols2> // D == head size
+__launch_bounds__(D, 1)
+static __global__ void flash_attn_stream_k_fixup(
+ float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
+ const int ne11, const int ne12, const int nbatch_fa) {
+ constexpr int ncols = ncols1*ncols2;
+
+ const int bidx0 = blockIdx.x;
+ const int j = blockIdx.y;
+ const int c = blockIdx.z;
+ const int jc = j*ncols2 + c;
+ const int tid = threadIdx.x;
+
+ const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there is more than one Q matrix per K/V matrix.
+
+ const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
+ const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+ const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
+
+ const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+ const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+
+ const bool did_not_have_any_data = kbc0 == kbc0_stop;
+ const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
+ const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+ if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+ return;
+ }
+
+ // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
+ const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
+ const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+ const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+ const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
+
+ const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
+
+ if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
+ return;
+ }
+
+ dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
+
+ // Load the partial result that needs a fixup:
+ float dst_val = 0.0f;
+ float max_val = 0.0f;
+ float rowsum = 0.0f;
+ {
+ dst_val = *dst;
+
+ const float2 tmp = dst_fixup[bidx0*ncols + jc];
+ max_val = tmp.x;
+ rowsum = tmp.y;
+ }
+
+ // Iterate over previous blocks and compute the combined results.
+ // All CUDA blocks that get here must have a previous block that needs a fixup.
+ int bidx = bidx0 - 1;
+ int kbc_stop = kbc0;
+ while(true) {
+ const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+ if (kbc == kbc_stop) { // Did not have any data.
+ bidx--;
+ kbc_stop = kbc;
+ continue;
+ }
+
+ const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
+
+ const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
+
+ // Scale the current and new value accumulators depending on the max. values.
+ const float max_val_new = fmaxf(max_val, tmp.x);
+
+ const float diff_val = max_val - max_val_new;
+ const float diff_add = tmp.x - max_val_new;
+
+ const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+ const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+
+ dst_val = scale_val*dst_val + scale_add*dst_add;
+ rowsum = scale_val*rowsum + scale_add*tmp.y;
+
+ max_val = max_val_new;
+
+ // If this block started in a previous tile we are done and don't need to combine additional partial results.
+ if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+ break;
+ }
+ bidx--;
+ kbc_stop = kbc;
+ }
+
+ // Write back final result:
+ *dst = dst_val / rowsum;
+}
+
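+// Combines the results of parallel_blocks partial attention computations over disjoint KV slices:
+// dst = sum_l expf(max_l - max_global) * VKQ_l / sum_l expf(max_l - max_global) * rowsum_l.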
+template<int D> // D == head size
+__launch_bounds__(D, 1)
+static __global__ void flash_attn_combine_results(
+ const float * __restrict__ VKQ_parts,
+ const float2 * __restrict__ VKQ_meta,
+ float * __restrict__ dst,
+ const int parallel_blocks) {
+ // Dimension 0: threadIdx.x
+ // Dimension 1: blockIdx.x
+ // Dimension 2: blockIdx.y
+ // Dimension 3: blockIdx.z
+ // Memory layout is permuted with [0, 2, 1, 3]
+
+ const int ne01 = gridDim.x;
+ const int ne02 = gridDim.y;
+
+ const int col = blockIdx.x;
+ const int head = blockIdx.y;
+ const int sequence = blockIdx.z;
+
+ const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+ VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+ VKQ_meta += j_dst_unrolled * parallel_blocks;
+ dst += j_dst_unrolled * D;
+
+ const int tid = threadIdx.x;
+ __builtin_assume(tid < D);
+
+ extern __shared__ float2 meta[];
+ for (int i = tid; i < 2*parallel_blocks; i += D) {
+ ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
+ }
+
+ __syncthreads();
+
+ float kqmax = meta[0].x;
+ for (int l = 1; l < parallel_blocks; ++l) {
+ kqmax = max(kqmax, meta[l].x);
+ }
+
+ float VKQ_numerator = 0.0f;
+ float VKQ_denominator = 0.0f;
+ for (int l = 0; l < parallel_blocks; ++l) {
+ const float KQ_max_scale = expf(meta[l].x - kqmax);
+
+ VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
+ VKQ_denominator += KQ_max_scale * meta[l].y;
+ }
+
+ dst[tid] = VKQ_numerator / VKQ_denominator;
+}
+
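+// Common host-side launcher for the FlashAttention kernels: converts K/V to f16 if the kernel needs it, optionally
+// precomputes KV_max from the mask, chooses between the stream-k and the fixed-split ("parallel_blocks") decomposition,
+// and launches the attention kernel plus any required fixup/combine kernels.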
+template <int DV, int ncols1, int ncols2>
+void launch_fattn(
+ ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
+ const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
+) {
+ constexpr int ncols = ncols1 * ncols2;
+
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+
+ const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
+
+ const ggml_tensor * mask = dst->src[3];
+ const ggml_tensor * sinks = dst->src[4];
+
+ ggml_tensor * KQV = dst;
+
+ GGML_ASSERT(Q->type == GGML_TYPE_F32);
+ GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));
+ GGML_ASSERT(K->nb[0] == ggml_element_size(K));
+ GGML_ASSERT(V->nb[0] == ggml_element_size(V));
+
+ GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+
+ ggml_cuda_pool & pool = ctx.pool();
+ cudaStream_t main_stream = ctx.stream();
+ const int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
+ const int nsm = ggml_cuda_info().devices[id].nsm;
+
+ ggml_cuda_pool_alloc<half> K_f16(pool);
+ ggml_cuda_pool_alloc<half> V_f16(pool);
+ ggml_cuda_pool_alloc<int> KV_max(pool);
+ ggml_cuda_pool_alloc<float> dst_tmp(pool);
+ ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+ const char * K_data = (const char *) K->data;
+ size_t nb11 = K->nb[1];
+ size_t nb12 = K->nb[2];
+ size_t nb13 = K->nb[3];
+
+ const char * V_data = (const char *) V->data;
+ size_t nb21 = V->nb[1];
+ size_t nb22 = V->nb[2];
+ size_t nb23 = V->nb[3];
+
+ if (need_f16_K && K->type != GGML_TYPE_F16) {
+ const size_t bs = ggml_blck_size(K->type);
+ const size_t ts = ggml_type_size(K->type);
+
+ K_f16.alloc(ggml_nelements(K));
+ if (ggml_is_contiguously_allocated(K)) {
+ to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+ to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+
+ nb11 = nb11*bs*sizeof(half)/ts;
+ nb12 = nb12*bs*sizeof(half)/ts;
+ nb13 = nb13*bs*sizeof(half)/ts;
+ } else {
+ GGML_ASSERT(K->nb[0] == ts);
+ to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type);
+ const int64_t s01 = nb11 / ts;
+ const int64_t s02 = nb12 / ts;
+ const int64_t s03 = nb13 / ts;
+ to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
+
+ nb11 = K->ne[0] * sizeof(half);
+ nb12 = K->ne[1] * nb11;
+ nb13 = K->ne[2] * nb12;
+ }
+ K_data = (char *) K_f16.ptr;
+ }
+
+ if (need_f16_V && V->type != GGML_TYPE_F16) {
+ if (V_is_K_view) {
+ V_data = K_data;
+ nb21 = nb11;
+ nb22 = nb12;
+ nb23 = nb13;
+ } else {
+ const size_t bs = ggml_blck_size(V->type);
+ const size_t ts = ggml_type_size(V->type);
+
+ V_f16.alloc(ggml_nelements(V));
+ if (ggml_is_contiguously_allocated(V)) {
+ to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+ to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+ V_data = (char *) V_f16.ptr;
+
+ nb21 = nb21*bs*sizeof(half)/ts;
+ nb22 = nb22*bs*sizeof(half)/ts;
+ nb23 = nb23*bs*sizeof(half)/ts;
+ } else {
+ GGML_ASSERT(V->nb[0] == ts);
+ to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
+ const int64_t s01 = nb21 / ts;
+ const int64_t s02 = nb22 / ts;
+ const int64_t s03 = nb23 / ts;
+ to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
+
+ nb21 = V->ne[0] * sizeof(half);
+ nb22 = V->ne[1] * nb21;
+ nb23 = V->ne[2] * nb22;
+ }
+ V_data = (char *) V_f16.ptr;
+ }
+ }
+
+ const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
+ const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
+ const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
+
+ // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
+    // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
+    // there are multiple sequences of possibly different lengths.
+ if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+ const int s31 = mask->nb[1] / sizeof(half2);
+ const int s33 = mask->nb[3] / sizeof(half2);
+
+ const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
+ const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
+
+ const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
+ const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
+
+ KV_max.alloc(ne_KV_max);
+ flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
+ ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
+ CUDA_CHECK(cudaGetLastError());
+ }
+
+ const dim3 block_dim(warp_size, nwarps, 1);
+ int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
+ CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+ GGML_ASSERT(max_blocks_per_sm > 0);
+ int parallel_blocks = max_blocks_per_sm;
+
+ dim3 blocks_num;
+ if (stream_k) {
+ // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
+ const int max_blocks = max_blocks_per_sm*nsm;
+ const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
+ const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
+
+ const int nblocks_stream_k = max_blocks;
+
+ const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
+
+ blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
+ blocks_num.y = 1;
+ blocks_num.z = 1;
+
+ if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+ dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
+ }
+ } else {
+ const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
+
+ // parallel_blocks must not be larger than what the tensor size allows:
+ parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
+
+ // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
+ // Test whether parallel_blocks can be set to a higher value for better efficiency.
+ const int blocks_per_wave = nsm * max_blocks_per_sm;
+ int nwaves_best = 0;
+ int efficiency_percent_best = 0;
+ for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
+ const int nblocks_total = ntiles_total * parallel_blocks_test;
+ const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
+ const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
+
+ // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
+ if (efficiency_percent_best >= 95 && nwaves > nwaves_best) {
+ break;
+ }
+
+ if (efficiency_percent > efficiency_percent_best) {
+ nwaves_best = nwaves;
+ efficiency_percent_best = efficiency_percent;
+ parallel_blocks = parallel_blocks_test;
+ }
+ }
+
+ blocks_num.x = ntiles_x;
+ blocks_num.y = parallel_blocks;
+ blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];
+
+ if (parallel_blocks > 1) {
+ dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+ dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+ }
+ }
+
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+ float logit_softcap = 0.0f;
+
+ memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+ if (logit_softcap != 0.0f) {
+ scale /= logit_softcap;
+ }
+
+ const uint32_t n_head = Q->ne[2];
+ const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ // TODO other tensor dimensions after removal of WMMA kernel:
+ const uint3 ne01 = init_fastdiv_values(Q->ne[1]);
+
+ GGML_ASSERT(block_dim.x % warp_size == 0);
+ fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
+ (const char *) Q->data,
+ K_data,
+ V_data,
+ mask ? ((const char *) mask->data) : nullptr,
+ sinks ? ((const char *) sinks->data) : nullptr,
+ KV_max.ptr,
+ !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
+ scale, max_bias, m0, m1, n_head_log2, logit_softcap,
+ Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
+ K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
+ nb21, nb22, nb23,
+ mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+ mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
+ );
+ CUDA_CHECK(cudaGetLastError());
+
+ if (stream_k) {
+ if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+ const dim3 block_dim_combine(DV, 1, 1);
+ const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
+
+ flash_attn_stream_k_fixup<DV, ncols1, ncols2>
+ <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+ ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
+ }
+ } else if (parallel_blocks > 1) {
+ const dim3 block_dim_combine(DV, 1, 1);
+ const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
+ const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
+
+ flash_attn_combine_results<DV>
+ <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
+ (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
+ }
+ CUDA_CHECK(cudaGetLastError());
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh
new file mode 100644
index 0000000..0b8ef90
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -0,0 +1,1750 @@
+#include "common.cuh"
+#include "cp-async.cuh"
+#include "mma.cuh"
+#include "fattn-common.cuh"
+
+using namespace ggml_cuda_mma;
+
+// Config options for the MMA kernel.
+// Should not affect results, only speed/register pressure/shared memory use.
+struct fattn_mma_config {
+ int nthreads; // Number of threads per CUDA block.
+ int occupancy; // Targeted occupancy for the MMA kernel.
+ int nbatch_fa; // Number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
+ int nbatch_K2; // Number of K half2 values in direction of DKQ to load in parallel.
+ int nbatch_V2; // Number of V half2 values in direction of DV to load in parallel.
+ int nbatch_combine; // Number of VKQ half2 values in direction of DV to combine in parallel.
+    int nstages_target; // Ideal number of pipeline stages: 1 == always load data synchronously, 2 == preload data if there is hardware support.
+ bool Q_in_reg; // Whether the Q values should be kept permanently in registers.
+
+ constexpr __host__ __device__ fattn_mma_config(
+ int nthreads, int occupancy, int nbatch_fa, int nbatch_K2, int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) :
+ nthreads(nthreads), occupancy(occupancy), nbatch_fa(nbatch_fa), nbatch_K2(nbatch_K2), nbatch_V2(nbatch_V2), nbatch_combine(nbatch_combine),
+ nstages_target(nstages_target), Q_in_reg(Q_in_reg) {}
+};
+
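+// Expands to an early return of the tuned config for a matching (DKQ, DV, ncols) combination,
+// with compile-time range checks on the configured values.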
+#define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads_, occupancy_, nbatch_fa_, nbatch_K2_, nbatch_V2_, nbatch_combine_, nstages_target_, Q_in_reg_) \
+ if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \
+ static_assert((nthreads_) % 32 == 0 && (nthreads_) <= 512, "bad nthreads"); \
+ static_assert( (occupancy_) <= 8, "bad occupancy"); \
+ static_assert((nbatch_fa_) % 32 == 0 && (nbatch_fa_) <= 256, "bad nbatch_fa"); \
+ static_assert((nbatch_K2_) % 4 == 0 && (nbatch_K2_) <= 512, "bad nbatch_K2"); \
+ static_assert((nbatch_V2_) % 4 == 0 && (nbatch_V2_) <= 256, "bad nbatch_V2"); \
+ static_assert((nbatch_combine_) % 4 == 0 && (nbatch_combine_) <= 128, "bad nbatch_combine"); \
+ static_assert((nstages_target_) >= 1 && (nstages_target_) <= 2, "bad nstages_target"); \
+ return fattn_mma_config{(nthreads_), (occupancy_), (nbatch_fa_), (nbatch_K2_), (nbatch_V2_), (nbatch_combine_), (nstages_target_), (Q_in_reg_)}; \
+ } \
+
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) {
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 2, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 2, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 2, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 2, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
+
+ return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
+}
+
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) {
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 128, 2, 64, 128, 128, 128, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
+
+ return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
+
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
+
+ // TODO tune specifically for Volta
+ return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
+
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) {
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
+
+ // TODO tune specifically for RDNA
+ return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+}
+
+static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
+ if (ampere_mma_available(cc)) {
+ return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+ }
+ if (turing_mma_available(cc)) {
+ return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
+ }
+ if (amd_wmma_available(cc)) {
+ return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
+ }
+ GGML_ASSERT(volta_mma_available(cc));
+ return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
+}
+
+static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) {
+#if defined(AMPERE_MMA_AVAILABLE)
+ return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
+#elif defined(TURING_MMA_AVAILABLE)
+ return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
+#elif defined(VOLTA_MMA_AVAILABLE)
+ return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
+#elif defined(AMD_WMMA_AVAILABLE)
+ return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
+#else
+ GGML_UNUSED_VARS(DKQ, DV, ncols);
+ return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
+#endif // defined(AMPERE_MMA_AVAILABLE)
+}
+
+static __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nthreads;
+}
+
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nthreads;
+}
+
+static __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).occupancy;
+}
+
+static constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).occupancy;
+}
+
+static __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_fa;
+}
+
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_fa;
+}
+
+static __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_K2;
+}
+
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_K2;
+}
+
+static __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_V2;
+}
+
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_V2;
+}
+
+static __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_combine;
+}
+
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_combine;
+}
+
+static __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nstages_target;
+}
+
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nstages_target;
+}
+
+static __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).Q_in_reg;
+}
+
+static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) {
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
+}
+
+static constexpr __device__ int get_cols_per_thread() {
+#if defined(AMD_WMMA_AVAILABLE)
+ return 1; // RDNA has a single column.
+#else
+    return 2; // This is specifically KQ columns; Volta only has a single VKQ column.
+#endif // defined(AMD_WMMA_AVAILABLE)
+}
+
+static __host__ int get_cols_per_warp(const int cc) {
+ if (turing_mma_available(cc) || amd_wmma_available(cc)) {
+ return 16;
+ } else {
+ // Volta
+ return 32;
+ }
+}
+
+// ------------------------------------------------------------------------------------------------------------------
+
+static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
+ return cp_async_available(cc) && ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2, cc) : 0;
+}
+
+static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2) {
+#ifdef CP_ASYNC_AVAILABLE
+ return ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2) : 0;
+#else
+ GGML_UNUSED_VARS(DKQ, DV, ncols1, ncols2);
+ return 0;
+#endif // CP_ASYNC_AVAILABLE
+}
+
+// ------------------------------------------------------------------------------------------------------------------
+
+template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
+ const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
+ // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
+    // The minimum granularity with cp.async is 16 bytes; with synchronous data loading it is 4 bytes.
+ if constexpr (use_cp_async) {
+ static_assert(!oob_check, "OOB check not compatible with cp_async");
+ constexpr int preload = 64;
+ constexpr int h2_per_chunk = 16/sizeof(half2);
+ const int chunks_per_row = D2 / h2_per_chunk;
+
+ const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
+
+ auto load = [&] __device__ (auto n) {
+ const int stride_k = WARP_SIZE >> n;
+ const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
+ const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
+ const int stride_i = WARP_SIZE / stride_k;
+
+ if (k0_start == k0_stop) {
+ return;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+ const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+ if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+ break;
+ }
+
+#pragma unroll
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+ cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
+ }
+ }
+ };
+ // 1: max 32*16=512 bytes, 256 half
+ // 2: max 16*16=256 bytes, 128 half
+ // 3: max 8*16=128 bytes, 64 half
+ // 4: max 4*16= 64 bytes, 32 half
+ // 5: max 2*16= 32 bytes, 16 half
+ // 6: max 1*16= 16 bytes, 8 half
+ ggml_cuda_unroll<6>{}(load);
+ } else {
+ // TODO use ggml_cuda_memcpy_1
+ auto load = [&] __device__ (const int n) {
+ const int stride_k = WARP_SIZE >> n;
+ const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
+ const int k0_stop = D2 - D2 % (1*stride_k);
+ const int stride_i = WARP_SIZE / stride_k;
+
+ if (k0_start == k0_stop) {
+ return;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
+ const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+ if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
+ break;
+ }
+
+#pragma unroll
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+ tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
+ }
+ }
+ };
+ // 1: max 32* 4=128 bytes, 64 half
+ // 2: max 16* 4= 64 bytes, 32 half
+ // 3: max 8* 4= 32 bytes, 16 half
+ // 4: max 4* 4= 16 bytes, 8 half
+ ggml_cuda_unroll<4>{}(load);
+ }
+}
+
+template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
+static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
+ const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
+ const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
+ if constexpr (use_cp_async) {
+ static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa");
+ static_assert(!oob_check, "OOB check incompatible with cp_async");
+ constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
+ constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
+ constexpr int stride_j = nwarps * cols_per_warp;
+
+ const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
+
+#pragma unroll
+ for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
+ const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+ const int j_vram = fastmodulo(j0 + j_sram, ne01);
+
+ if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
+ break;
+ }
+
+ const int i = 8 * (threadIdx.x % (nbatch_fa/8));
+
+ cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i);
+ }
+ } else if constexpr (oob_check) {
+#pragma unroll
+ for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
+ const int j_sram = j1 + threadIdx.y;
+ const int j_vram = fastmodulo(j0 + j_sram, ne01);
+
+ if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
+ break;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
+ }
+ }
+ } else if constexpr (nbatch_fa < 2*WARP_SIZE) {
+ constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
+ constexpr int stride_j = nwarps * cols_per_warp;
+#pragma unroll
+ for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
+ const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+ const int j_vram = fastmodulo(j0 + j_sram, ne01);
+
+ if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
+ break;
+ }
+
+ const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);
+
+ ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
+ }
+ } else {
+#pragma unroll
+ for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
+ const int j_sram = j1 + threadIdx.y;
+ const int j_vram = fastmodulo(j0 + j_sram, ne01);
+
+ if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
+ break;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
+ const int i = i0 + 2*threadIdx.x;
+
+ ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
+ }
+ }
+ }
+}
+
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
+ bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
+ typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
+static __device__ __forceinline__ void flash_attn_ext_f16_iter(
+ const float2 * const __restrict__ Q_f2,
+ const half2 * const __restrict__ K_h2,
+ const half2 * const __restrict__ V_h2,
+ const half * const __restrict__ mask_h,
+ float2 * const __restrict__ dstk,
+ float2 * const __restrict__ dstk_fixup,
+ const float scale,
+ const float slope,
+ const float logit_softcap,
+ const uint3 ne01,
+ const int ne02,
+ const int stride_K,
+ const int stride_V,
+ const int stride_mask,
+ half2 * const __restrict__ tile_Q,
+ half2 * const __restrict__ tile_K,
+ half2 * const __restrict__ tile_V,
+ half * const __restrict__ tile_mask,
+ T_B_KQ * const __restrict__ Q_B,
+ T_C_VKQ * const __restrict__ VKQ_C,
+ float * const __restrict__ KQ_max,
+ float * const __restrict__ KQ_rowsum,
+ const int jt,
+ const int kb0,
+ const int k_VKQ_sup) {
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
+ constexpr int ncols = ncols1 * ncols2;
+ constexpr int cols_per_warp = T_B_KQ::I;
+ constexpr int cols_per_thread = get_cols_per_thread();
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
+ constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
+ constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
+ constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
+ constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);
+ constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2);
+
+ constexpr int stride_tile_Q = DKQ/2 + 4;
+ constexpr int stride_tile_K = nbatch_K2 + 4;
+
+ constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
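+    // Note: the +4 half2 (16 bytes) of padding per tile row presumably serve to avoid shared memory bank conflicts
+    // between consecutive rows.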
+
+ const int k_VKQ_0 = kb0 * nbatch_fa;
+#if defined(TURING_MMA_AVAILABLE)
+ T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
+#elif defined(AMD_WMMA_AVAILABLE)
+ T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
+#else // Volta
+ T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
+#endif // defined(TURING_MMA_AVAILABLE)
+
+ if constexpr (nstages > 1) {
+ static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
+        static_assert(!V_is_K_view, "K data reuse not implemented for multi-stage loading");
+        static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage loading");
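+        // Two-stage pipeline: the K tile (and mask) for this iteration were already prefetched, so the cp.async V
+        // load issued here overlaps with the KQ matrix multiplication below, and the K/mask loads for the next
+        // iteration (issued after the softmax) overlap with the VKQ matrix multiplication.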
+ constexpr bool use_cp_async = true;
+ cp_async_wait_all();
+ __syncthreads();
+ flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
+ (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V, k_VKQ_sup);
+ } else {
+ constexpr bool use_cp_async = nstages == 1;
+ if (ncols2 > 1 || mask_h) {
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
+ (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
+ }
+ }
+
+    // For MLA, K and V have the same data.
+ // Therefore, iterate over K in reverse and later re-use the data if possible.
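+    // Iterating in reverse leaves the K columns [0, nbatch_K2) in tile_K after the loop, so with V_is_K_view the
+    // VKQ loop below can take any V sub-tile that falls entirely within that range straight from shared memory.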
+#pragma unroll
+ for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
+ const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
+ const int k0_diff = k0_stop - k0_start;
+
+ if constexpr (nstages <= 1) {
+ constexpr bool use_cp_async = nstages == 1;
+ flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
+ (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup);
+ if (use_cp_async) {
+ cp_async_wait_all();
+ }
+ __syncthreads();
+ }
+
+ // Calculate tile of KQ:
+ if constexpr (Q_in_reg) {
+#pragma unroll
+ for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {
+ const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;
+#pragma unroll
+ for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
+ T_A_KQ K_A;
+ load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
+ if constexpr (cols_per_warp == 8) {
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
+ } else {
+ // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE)
+ // RDNA matrix C is column-major.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
+#else
+ // swap A and B for CUDA.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE)
+ }
+ }
+ }
+ } else {
+#pragma unroll
+ for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
+ load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
+
+#pragma unroll
+ for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {
+ const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;
+
+ T_A_KQ K_A;
+ load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
+
+ if constexpr (cols_per_warp == 8) {
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+ } else {
+ // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE)
+ // RDNA matrix C is column-major.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+#else
+ // swap A and B for CUDA.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE)
+ }
+ }
+ }
+ }
+
+ if constexpr (nstages <= 1) {
+ __syncthreads(); // Only needed if tile_K == tile_V.
+ }
+ }
+
+ if (use_logit_softcap) {
+ constexpr int stride = cols_per_warp == 8 ? np*T_C_KQ::I : np*T_C_KQ::J;
+ static_assert(nbatch_fa % stride == 0, "bad loop size");
+#pragma unroll
+ for (int i = 0; i < nbatch_fa/stride; ++i) {
+#pragma unroll
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
+ KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
+ }
+ }
+ }
+
+ float KQ_max_new[cols_per_thread];
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+ KQ_max_new[col] = KQ_max[col];
+ }
+ float KQ_rowsum_add[cols_per_thread] = {0.0f};
+
+ if constexpr (cols_per_warp == 8) {
+ if (ncols2 > 1 || mask_h) {
+#pragma unroll
+ for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::I) {
+ const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::I;
+#pragma unroll
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
+ const int i = i0 + T_C_KQ::get_i(l);
+ const int j = ((threadIdx.y / np)*T_C_KQ::J + T_C_KQ::get_j(l)) / ncols2;
+
+ KQ_C[i00/(np*T_C_KQ::I)].x[l] += slope * __half2float(tile_mask[j*(nbatch_fa + 8) + i]);
+ }
+ }
+ }
+
+ // Calculate softmax for each KQ column using the current max. value.
+ // The divisor is stored in KQ_rowsum and will be applied at the end.
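+        // Online softmax update per column, here and in the wide branch below (a sketch of the math):
+        //   m_new  = max(m_old, max_i(KQ_i) + FATTN_KQ_MAX_OFFSET)
+        //   p_i    = exp(KQ_i - m_new)
+        //   rowsum = exp(m_old - m_new)*rowsum_old + sum_i(p_i)
+        // The VKQ accumulators are rescaled by exp(m_old - m_new) further down.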
+ static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+ for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
+#pragma unroll
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
+ if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
+#if defined(AMD_WMMA_AVAILABLE)
+ constexpr int KQ_idx = 0;
+#else
+ // Turing + Volta:
+ const int KQ_idx = l % 2;
+#endif // defined(AMD_WMMA_AVAILABLE)
+ KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
+ }
+ }
+ }
+
+ // Values per KQ column are spread across 8 threads:
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+#pragma unroll
+ for (int offset = 16; offset >= 4; offset >>= 1) {
+ KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+ }
+ }
+
+ static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+ for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
+#pragma unroll
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
+ if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
+#if defined(AMD_WMMA_AVAILABLE)
+ constexpr int KQ_idx = 0;
+#else
+ // Turing + Volta:
+ const int KQ_idx = l % 2;
+#endif // defined(AMD_WMMA_AVAILABLE)
+ KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
+ KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
+ } else {
+ KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
+ }
+ }
+ }
+ } else { // not Turing mma or T_B_KQ::I > 8
+ if (ncols2 > 1 || mask_h) {
+#pragma unroll
+ for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) {
+ const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J;
+#pragma unroll
+ for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) {
+ const int i = (i0 + T_C_KQ::get_j(l0)) / 2;
+ const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l0)) / ncols2;
+
+ const float2 tmp = __half22float2(((const half2 *)tile_mask)[j*(nbatch_fa/2 + 4) + i]);
+ KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x;
+ KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y;
+ }
+ }
+ }
+
+ // Calculate softmax for each KQ column using the current max. value.
+ // The divisor is stored in KQ_rowsum and will be applied at the end.
+ static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size");
+#pragma unroll
+ for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
+#pragma unroll
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
+ if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
+#if defined(AMD_WMMA_AVAILABLE)
+ constexpr int KQ_idx = 0;
+#else
+ // Turing + Volta:
+ const int KQ_idx = (l/2) % 2;
+#endif // defined(AMD_WMMA_AVAILABLE)
+ KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
+ }
+ }
+ }
+
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+#if defined(TURING_MMA_AVAILABLE)
+ // Values per KQ column are spread across 4 threads:
+ constexpr int offset_first = 2;
+ constexpr int offset_last = 1;
+#elif defined(AMD_WMMA_AVAILABLE)
+ // Values per KQ column are spread across 2 threads:
+ constexpr int offset_first = 16;
+ constexpr int offset_last = 16;
+#else // Volta
+ // Values per KQ column are spread across 2 threads:
+ constexpr int offset_first = 2;
+ constexpr int offset_last = 2;
+#endif // defined(TURING_MMA_AVAILABLE)
+#pragma unroll
+ for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
+ KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+ }
+ }
+
+ static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size");
+#pragma unroll
+ for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
+#pragma unroll
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
+ if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
+#if defined(AMD_WMMA_AVAILABLE)
+ constexpr int KQ_idx = 0;
+#else
+ // Turing + Volta:
+ const int KQ_idx = (l/2) % 2;
+#endif // defined(AMD_WMMA_AVAILABLE)
+ KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
+ KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
+ } else {
+ KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
+ }
+ }
+ }
+ }
+
+ {
+ float KQ_max_scale[cols_per_thread];
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+ const float KQ_max_diff = KQ_max[col] - KQ_max_new[col];
+ KQ_max_scale[col] = expf(KQ_max_diff);
+ KQ_max[col] = KQ_max_new[col];
+
+ *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
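+            // Branchless flush to zero: multiplying the bit pattern by 0 or 1 sets KQ_max_scale to 0.0f when
+            // KQ_max_diff is below SOFTMAX_FTZ_THRESHOLD, i.e. when the old contribution is negligible anyway.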
+
+ // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+ KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
+ }
+
+#if defined(TURING_MMA_AVAILABLE)
+ if constexpr (cols_per_warp == 8) {
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
+#pragma unroll
+ for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
+#pragma unroll
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
+ }
+ }
+ } else {
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+ for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {
+ VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;
+ }
+ }
+ }
+ }
+#elif defined(AMD_WMMA_AVAILABLE)
+ const half2 KQ_max_scale_h2 = make_half2(
+ KQ_max_scale[0], KQ_max_scale[0]);
+#pragma unroll
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
+ }
+ }
+#else // Volta
+ const half2 KQ_max_scale_h2 = make_half2(
+ KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
+#pragma unroll
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
+ }
+ }
+#endif // defined(TURING_MMA_AVAILABLE)
+ }
+
+ // Convert KQ C tiles into B tiles for VKQ calculation:
+ T_B_VKQ B[nbatch_fa/(np*2*T_B_VKQ::J)];
+ static_assert(nbatch_fa % (np*2*T_B_VKQ::J) == 0, "bad loop size");
+ if constexpr (cols_per_warp == 8) {
+#pragma unroll
+ for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {
+ B[k] = get_transposed(get_half2(KQ_C[k]));
+ }
+ } else {
+ for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {
+ B[k] = get_half2(KQ_C[k]);
+ }
+ }
+
+ if constexpr (nstages > 1) {
+        static_assert(!V_is_K_view, "K data reuse not implemented for multi-stage loading");
+        // Preload mask and K tile for next iteration:
+ constexpr bool use_cp_async = true;
+ cp_async_wait_all();
+ __syncthreads();
+ if (!last_iter) {
+ if (ncols2 > 1 || mask_h) {
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
+ (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
+ }
+ flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
+ (K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
+ }
+ }
+
+
+#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
+ T_A_VKQ A_identity;
+ make_identity_mat(A_identity);
+#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
+
+    // Calculate the VKQ tile; i0 needs to use logical rather than physical elements due to the transposition of V:
+#pragma unroll
+ for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
+ static_assert(DV % (2*nbatch_V2) == 0, "bad loop size");
+ const int i0_stop = i0_start + 2*nbatch_V2;
+ const int i0_diff = i0_stop - i0_start;
+
+ if constexpr (nstages <= 1) {
+ if (!V_is_K_view || i0_stop > 2*nbatch_K2) {
+ constexpr bool use_cp_async = nstages == 1;
+ flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
+ (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
+ if (use_cp_async) {
+ cp_async_wait_all();
+ }
+ __syncthreads();
+ }
+ }
+ const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
+
+#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
+#pragma unroll
+ for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
+ static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size");
+#pragma unroll
+ for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) {
+ const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
+
+ T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
+#if defined(LDMATRIX_TRANS_AVAILABLE)
+ load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+#else
+ // TODO: Try to transpose tile_V when loading gmem to smem.
+ // Use mma to transpose T_A_VKQ for RDNA.
+ T_A_VKQ A_trans;
+ load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+ mma(A, A_trans, A_identity);
+#endif // defined(LDMATRIX_TRANS_AVAILABLE)
+ if constexpr (T_B_KQ::I == 8) {
+ mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
+ } else {
+ // Wide version of VKQ_C is column-major.
+#if defined(AMD_WMMA_AVAILABLE)
+ // RDNA matrix C is column-major.
+ mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
+#else
+ // swap A and B for CUDA.
+ mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
+#endif // defined(AMD_WMMA_AVAILABLE)
+ }
+ }
+ }
+#else // Volta
+ constexpr int i0_stride = 2*T_C_VKQ::J;
+#pragma unroll
+ for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
+ static_assert(nbatch_fa % (np*T_A_VKQ::I) == 0, "bad loop size");
+ static_assert(2*T_B_VKQ::J == T_A_VKQ::I, "bad tile sizes");
+#pragma unroll
+ for (int k00 = 0; k00 < nbatch_fa; k00 += np*T_A_VKQ::I) {
+ const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::I;
+
+ T_A_VKQ A; // Transposed in both SRAM and registers, load normally.
+ load_ldmatrix(A, tile_V_i + k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+ mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
+ }
+ }
+#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ if constexpr (nstages <= 1) {
+ __syncthreads(); // Only needed if tile_K == tile_V.
+ }
+ }
+#else
+ GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup,
+ scale, slope, logit_softcap, ne01, ne02,
+ stride_K, stride_V, stride_mask,
+ tile_Q, tile_K, tile_V, tile_mask,
+ Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+ NO_DEVICE_CODE;
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
+}
+
+#if defined(TURING_MMA_AVAILABLE)
+template<int ncols> struct mma_tile_sizes {
+ using T_A_KQ = tile<16, 8, half2>; // row-major
+ using T_B_KQ = tile<16, 8, half2>; // column-major
+ using T_C_KQ = tile<16, 16, float>; // column-major
+ using T_A_VKQ = tile<16, 8, half2>; // row-major
+ using T_B_VKQ = tile<16, 8, half2>; // column-major
+ using T_C_VKQ = tile<16, 8, half2>; // column-major
+};
+template<> struct mma_tile_sizes<8> {
+ using T_A_KQ = tile<16, 8, half2>; // row-major
+ using T_B_KQ = tile< 8, 8, half2>; // column-major
+ using T_C_KQ = tile<16, 8, float>; // row-major
+ using T_A_VKQ = tile<16, 8, half2>; // row-major
+ using T_B_VKQ = tile< 8, 8, half2>; // column-major
+ using T_C_VKQ = tile<16, 4, half2>; // row-major
+};
+#elif defined(AMD_WMMA_AVAILABLE)
+template<int ncols> struct mma_tile_sizes {
+ using T_A_KQ = tile<16, 8, half2>; // row-major
+ using T_B_KQ = tile<16, 8, half2>; // column-major
+ using T_C_KQ = tile<16, 16, float>; // column-major
+ using T_A_VKQ = tile<16, 8, half2>; // row-major
+ using T_B_VKQ = tile<16, 8, half2>; // column-major
+ using T_C_VKQ = tile<16, 8, half2>; // column-major
+};
+#else // Volta
+template<int ncols> struct mma_tile_sizes {
+ using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
+ using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
+ using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major
+ using T_A_VKQ = tile< 8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED>; // column-major
+ using T_B_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
+ using T_C_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
+};
+#endif // defined(TURING_MMA_AVAILABLE)
+
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup>
+static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
+ const float2 * const __restrict__ Q_f2,
+ const half2 * const __restrict__ K_h2,
+ const half2 * const __restrict__ V_h2,
+ const half * const __restrict__ mask_h,
+ const float * const __restrict__ sinks_f,
+ float2 * const __restrict__ dstk,
+ float2 * const __restrict__ dstk_fixup,
+ const float scale,
+ const float slope,
+ const float logit_softcap,
+ const uint3 ne01,
+ const int ne02,
+ const int gqa_ratio,
+ const int ne11,
+ const int stride_Q1,
+ const int stride_Q2,
+ const int stride_K,
+ const int stride_V,
+ const int stride_mask,
+ const int jt,
+ const int zt_gqa,
+ const int kb0_start,
+ const int kb0_stop) {
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+ constexpr int ncols = ncols1 * ncols2;
+ using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ;
+ using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ;
+ using T_C_KQ = typename mma_tile_sizes<ncols>::T_C_KQ;
+ using T_A_VKQ = typename mma_tile_sizes<ncols>::T_A_VKQ;
+ using T_B_VKQ = typename mma_tile_sizes<ncols>::T_B_VKQ;
+ using T_C_VKQ = typename mma_tile_sizes<ncols>::T_C_VKQ;
+
+ constexpr int cols_per_warp = T_B_KQ::I;
+ constexpr int cols_per_thread = get_cols_per_thread();
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
+ constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
+ constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
+ constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
+ constexpr int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols);
+ constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);
+ constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2);
+
+ if (cols_per_warp > ncols) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
+ static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
+
+ constexpr int stride_tile_Q = DKQ/2 + 4;
+ constexpr int stride_tile_K = nbatch_K2 + 4;
+
+ constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
+ constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
+
+ extern __shared__ half2 tile_Q[];
+ half2 * tile_K = Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q;
+ half2 * tile_V = nstages > 1 ? tile_K + nbatch_fa * stride_tile_K : tile_K;
+ half * tile_mask = (half *) (nstages > 1 ? tile_V + nbatch_fa * stride_tile_V : tile_V + nbatch_fa * stride_tile_KV_max);
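+    // Shared memory layout: tile_Q comes first; tile_K aliases tile_Q when Q is kept in registers (tile_Q is only
+    // needed transiently in that case); tile_V aliases tile_K for single-stage loading and is a separate buffer for
+    // the two-stage pipeline; tile_mask follows the K/V tiles.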
+
+ T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
+#if defined(TURING_MMA_AVAILABLE)
+ T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
+#elif defined(AMD_WMMA_AVAILABLE)
+ T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
+#else // Volta
+ T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
+#endif // defined(TURING_MMA_AVAILABLE)
+
+ float KQ_rowsum[cols_per_thread] = {0.0f};
+ float KQ_max[cols_per_thread];
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+ KQ_max[col] = -FLT_MAX/2.0f;
+ }
+
+ // Load Q data into tile_Q, either temporarily or permanently.
+ // Q in registers is faster, but register pressure is the biggest bottleneck.
+ // The loading is done with decreasing granularity for D for better memory bandwidth.
+ const half2 scale_h2 = make_half2(scale, scale);
+#pragma unroll
+ for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+ const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
+ const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k);
+ const int stride_jc = WARP_SIZE / stride_k;
+
+ if (k0_start == k0_stop) {
+ continue;
+ }
+
+#pragma unroll
+ for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
+ const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+ if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
+ break;
+ }
+
+ const int j = jc / ncols2;
+ const int c = jc % ncols2;
+
+ if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
+#pragma unroll
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+ const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
+ tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
+ }
+ } else {
+#pragma unroll
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+ tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
+ }
+ }
+ }
+ }
+
+ __syncthreads();
+
+ if (Q_in_reg) {
+ const int j0 = (threadIdx.y / np) * cols_per_warp;
+
+#pragma unroll
+ for (int k0 = 0; k0 < DKQ/2; k0 += T_B_KQ::J) {
+ load_ldmatrix(Q_B[k0/T_B_KQ::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
+ }
+ }
+
+ __syncthreads();
+
+ int kb0 = kb0_start;
+
+ // Preload mask and K data for first iteration when using cp_async with multiple stages:
+ if constexpr (nstages > 1) {
+ static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
+ constexpr bool use_cp_async = true;
+ constexpr bool oob_check = false;
+ constexpr int k_VKQ_sup = nbatch_fa;
+ if (ncols2 > 1 || mask_h) {
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
+ (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
+ }
+ flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
+ (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
+ }
+
+ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
+ if constexpr (ncols2 == 1) {
+ constexpr bool oob_check = true;
+ for (; kb0 < kb0_stop-1; ++kb0) {
+ constexpr bool last_iter = false;
+ constexpr int k_VKQ_sup = nbatch_fa;
+ flash_attn_ext_f16_iter
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
+ T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+ (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+ KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+ }
+ constexpr bool last_iter = true;
+ const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
+ flash_attn_ext_f16_iter
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
+ T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+ (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+ KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+ } else {
+ constexpr bool oob_check = false;
+ for (; kb0 < kb0_stop-1; ++kb0) {
+ constexpr bool last_iter = false;
+ constexpr int k_VKQ_sup = nbatch_fa;
+ flash_attn_ext_f16_iter
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
+ T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+ (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+ KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+ }
+ constexpr bool last_iter = true;
+ constexpr int k_VKQ_sup = nbatch_fa;
+ flash_attn_ext_f16_iter
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
+ T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+ (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+ KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+ }
+
+    // With multi-stage loading there is no __syncthreads at the end of the iteration,
+    // so there can be a race condition on shared memory accesses when combining/writing back the results.
+ if constexpr (nstages > 1 && nwarps*cols_per_warp > nbatch_fa) {
+ __syncthreads();
+ }
+
+ // Finally, sum up partial KQ rowsums.
+ {
+#if defined(TURING_MMA_AVAILABLE)
+ // The partial sums are spread across 8/4 threads.
+ constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
+ constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
+#elif defined(AMD_WMMA_AVAILABLE)
+ // The partial sums are spread across 2 threads.
+ constexpr int offset_first = 16;
+ constexpr int offset_last = 16;
+#else // Volta
+ // The partial sums are spread across 2 threads.
+ constexpr int offset_first = 2;
+ constexpr int offset_last = 2;
+#endif // defined(TURING_MMA_AVAILABLE)
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+#pragma unroll
+ for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
+ KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
+ }
+ }
+ }
+
+    // If attention sinks are used, potentially re-scale if KQ_max is small.
+    // Also add the sink as a value to KQ_rowsum; this is done after the synchronization of KQ_rowsum,
+    // so it's being done unconditionally for every thread.
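+    // Per column this amounts to: m_new = max(m, sink), rowsum = exp(m - m_new)*rowsum + exp(sink - m_new), with
+    // the VKQ accumulators rescaled by exp(m - m_new); the sink acts as an extra logit that only contributes to the
+    // softmax denominator and has no corresponding V row.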
+ if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
+ float KQ_max_scale[cols_per_thread];
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+ const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);
+ const float sink = sinks_f[jc % ncols2];
+
+ const float KQ_max_new = fmaxf(KQ_max[col], sink);
+ const float KQ_max_diff = KQ_max[col] - KQ_max_new;
+ KQ_max_scale[col] = expf(KQ_max_diff);
+ KQ_max[col] = KQ_max_new;
+
+ *((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
+
+ const float KQ_max_add = expf(sink - KQ_max_new);
+ KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
+ }
+
+#if defined(TURING_MMA_AVAILABLE)
+ if constexpr (cols_per_warp == 8) {
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
+#pragma unroll
+ for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
+#pragma unroll
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
+ }
+ }
+ } else {
+#pragma unroll
+ for (int col = 0; col < cols_per_thread; ++col) {
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+ for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {
+ VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;
+ }
+ }
+ }
+ }
+#elif defined(AMD_WMMA_AVAILABLE)
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
+#pragma unroll
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
+ }
+ }
+#else // Volta
+ const int col = (threadIdx.x / 2) % 2;
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
+#pragma unroll
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
+#pragma unroll
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
+ }
+ }
+#endif // defined(TURING_MMA_AVAILABLE)
+ }
+
+    // Combine VKQ accumulator values if np > 1.
+    // It's also faster to do small writes to shared memory followed by one large write to VRAM than to do small writes to VRAM directly.
+    // So the VKQ accumulators are also written to shared memory in column-major format if np == 1.
+
+ constexpr int tile_stride = nbatch_combine + 4;
+ static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
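+    // Each tile row used for combining thus holds nbatch_combine half2 of data followed by 4 half2 (16 bytes) of
+    // padding; a float2 is stored at half2 offset nbatch_combine inside that padding: per warp {KQ max, KQ rowsum},
+    // overwritten with {KQ max scale, combined rowsum} once the parallel warps have been combined.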
+
+ if constexpr (cols_per_warp == 8) {
+ const int jc_cwmo = (threadIdx.x % (2*T_C_VKQ::J)) / T_C_VKQ::J; // jc combine write meta offset
+ const int jc_cwm = threadIdx.y*(2*T_C_VKQ::J) + 2*T_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
+ const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum
+
+ if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*T_C_VKQ::J) {
+ // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
+ ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
+ }
+
+ __syncthreads();
+
+ if (np == 1) {
+            // No combination is needed; the meta data can be written directly from registers to VRAM.
+ if (needs_fixup && threadIdx.x < T_B_KQ::I) {
+ float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+ dstk_fixup_meta[jc_cwm] = KQ_cmr;
+ }
+ if (is_fixup && threadIdx.x < T_B_KQ::I) {
+ float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+ dstk_fixup_meta[jc_cwm] = KQ_cmr;
+ }
+ }
+ } else {
+ // jc_cwm = jc combine write meta
+ // KQ_cmr = KQ combine max rowsum
+ // Use the 16 bytes of padding in each Q column to store the meta data: KQ max, KQ rowsum, KQ max scale.
+#if defined(TURING_MMA_AVAILABLE)
+ const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
+ const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
+ const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
+#elif defined(AMD_WMMA_AVAILABLE)
+ const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
+ const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
+ const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
+#else // Volta
+ const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
+ const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
+ const bool thread_should_write = T_C_KQ::J == 8 || T_C_KQ::get_j(threadIdx.x & 2) < 8;
+#endif // defined(TURING_MMA_AVAILABLE)
+
+ if (((!needs_fixup && !is_fixup) || np > 1) && thread_should_write) {
+ ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
+ }
+
+ __syncthreads();
+
+ if (np == 1) {
+            // No combination is needed; the meta data can be written directly from registers to VRAM.
+ if (needs_fixup && thread_should_write) {
+ float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+ dstk_fixup_meta[jc_cwm] = KQ_cmr;
+ }
+ if (is_fixup && thread_should_write) {
+ float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+ dstk_fixup_meta[jc_cwm] = KQ_cmr;
+ }
+ }
+ }
+
+ if (np > 1 && threadIdx.y % np == 0) {
+ // Combine the meta data for parallel warps via shared memory.
+ // Warps with threadIdx.y % np != 0 must NOT return early.
+ // All threads must return simultaneously to avoid race conditions with work on the next tile.
+
+ constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
+
+ const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
+ float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
+ float2 meta[nmeta];
+#pragma unroll
+ for (int imeta = 0; imeta < nmeta; ++imeta) {
+ meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
+ }
+
+ float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
+#pragma unroll
+ for (int imeta = 1; imeta < nmeta; ++imeta) {
+ KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x);
+ }
+#pragma unroll
+ for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
+ if (offset < WARP_SIZE) {
+ KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
+ }
+ }
+
+ float KQ_cms[nmeta]; // KQ combine max scale per warp.
+#pragma unroll
+ for (int imeta = 0; imeta < nmeta; ++imeta) {
+ KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn);
+ }
+
+ float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps.
+#pragma unroll
+ for (int imeta = 1; imeta < nmeta; ++imeta) {
+ KQ_crs += KQ_cms[imeta]*meta[imeta].y;
+ }
+#pragma unroll
+ for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
+ if (offset < WARP_SIZE) {
+ KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
+ }
+ }
+
+ __syncthreads();
+
+ // Write back combined meta data:
+#pragma unroll
+ for (int imeta = 0; imeta < nmeta; ++imeta) {
+ if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
+ // Combined KQ max scale + rowsum.
+ meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
+ }
+ }
+
+ // Combined KQ max + rowsum.
+ static_assert(cols_per_warp <= WARP_SIZE);
+ if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+ float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+ dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+ }
+ if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+ float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+ dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+ }
+ } else if (np > 1) {
+ // Warps with threadIdx.y % np == 0 execute a __syncthreads() in the if branch.
+ // Therefore, all other warps also need to execute a __syncthreads().
+ // Otherwise the points at which warps synchronize with each other would become misaligned.
+ __syncthreads();
+ }
+
+#pragma unroll
+ for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
+ if constexpr (cols_per_warp == 8) {
+ const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data
+#pragma unroll
+ for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) {
+ const T_B_KQ B = get_transposed(VKQ_C[(k00 + k1)/T_B_KQ::J]); // Conversion of C to B matrix puts it in column-major format.
+
+#pragma unroll
+ for (int l = 0; l < T_B_KQ::ne; ++l) {
+ const int k = k1 + T_B_KQ::get_j(l);
+
+ tile_Q[jc_cwd*tile_stride + k] = B.x[l];
+ }
+ }
+ } else {
+ const int j0 = threadIdx.y*cols_per_warp;
+#pragma unroll
+ for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {
+#pragma unroll
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
+ const int j = j0 + T_C_VKQ::get_i(l);
+ const int k = k1 + T_C_VKQ::get_j(l);
+
+ tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l];
+ }
+ }
+ }
+
+ __syncthreads();
+
+ if (np == 1 || threadIdx.y % np == 0) {
+ // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
+ // The values after that are for the partial results of the individual blocks.
+ float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
+
+#pragma unroll
+ for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+ const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
+ const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k);
+ const int stride_jc = WARP_SIZE / stride_k;
+
+ if (k0_start == k0_stop) {
+ continue;
+ }
+
+#pragma unroll
+ for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
+ const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+ if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
+ break;
+ }
+
+ const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;
+
+ const int j_dst = jc_dst / ncols2;
+ const int c_dst = jc_dst % ncols2;
+
+ if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
+ continue;
+ }
+
+ const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
+#pragma unroll
+ for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+ const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+ float2 dstk_val = make_float2(0.0f, 0.0f);
+#pragma unroll
+ for (int ip = 0; ip < np; ++ip) {
+ const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * tile_stride + 0];
+ const float2 dstk_val_add = __half22float2(tile_Q[(jc_tile_K + ip*cols_per_warp) * tile_stride + k]);
+ dstk_val.x += dstk_val_add.x*KQ_crs;
+ dstk_val.y += dstk_val_add.y*KQ_crs;
+ }
+
+ if (!needs_fixup && !is_fixup) {
+ const float KQ_rowsum_j = meta_j[1];
+ dstk_val.x /= KQ_rowsum_j;
+ dstk_val.y /= KQ_rowsum_j;
+ }
+
+ if (is_fixup) {
+ dstk_fixup_data[jc_dst*(DV/2) + k00 + k] = dstk_val;
+ } else {
+ dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(DV/2) + k00 + k] = dstk_val;
+ }
+ }
+ }
+ }
+ }
+ if (np > 1) {
+ __syncthreads();
+ }
+ }
+#else
+ GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
+ scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
+ stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
+ jt, kb0_start, kb0_stop);
+ NO_DEVICE_CODE;
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
+}
+
+template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
+__launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
+static __global__ void flash_attn_ext_f16(
+ const char * __restrict__ Q,
+ const char * __restrict__ K,
+ const char * __restrict__ V,
+ const char * __restrict__ mask,
+ const char * __restrict__ sinks,
+ const int * __restrict__ KV_max,
+ float * __restrict__ dst,
+ float2 * __restrict__ dst_meta,
+ const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
+ const float logit_softcap,
+ const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
+#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
+
+ // Skip unused kernel variants for faster compilation:
+ if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+#ifdef VOLTA_MMA_AVAILABLE
+ if (ncols1*ncols2 < 32) {
+ NO_DEVICE_CODE;
+ return;
+ }
+#endif // VOLTA_MMA_AVAILABLE
+
+#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+ if (ncols1*ncols2 > 32) {
+ NO_DEVICE_CODE;
+ return;
+ }
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+
+#if defined(AMD_WMMA_AVAILABLE)
+ if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) {
+ NO_DEVICE_CODE;
+ return;
+ }
+#endif // defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int ncols = ncols1 * ncols2;
+ constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
+ constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
+ constexpr int nwarps = nthreads / WARP_SIZE;
+
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+
+ const int stride_Q1 = nb01 / sizeof(float2);
+ const int stride_Q2 = nb02 / sizeof(float2);
+ const int stride_K = nb11 / sizeof(half2);
+ const int stride_mask = nb31 / sizeof(half);
+
+ const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
+
+ const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
+ const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
+ const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
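+    // iter_k: K/V tiles per head, iter_j: Q row tiles, iter_z_gqa: GQA groups of ncols2 Q heads per K/V head.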
+
+ // kbc == k block continuous, current index in continuous ijk space.
+ int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+ const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+
+    // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
+    // For this we need to track both whether a block is missing the beginning of a tile (needs_fixup)
+    // and whether it is missing the end of a tile (is_fixup).
+    // In the most general case >2 seams can fall into the same tile.
+
+ // kb0 == k start index when in the output tile.
+ int kb0_start = kbc % iter_k;
+ int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
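+    // Example with hypothetical numbers: for iter_k == 8, kbc == 21, kbc_stop == 35 this block first finishes the
+    // tile it lands in (kb0 = 5..7, needs_fixup), then processes one complete tile (kb0 = 0..7), and finally handles
+    // the partial tile kb0 = 0..2, whose result goes to the fixup buffer (is_fixup) to be combined later.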
+
+ while (kbc < kbc_stop && kb0_stop == iter_k) {
+ // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
+ const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
+ const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+ const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+ const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
+
+ const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
+
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
+ const half * mask_h = ncols2 == 1 && !mask ? nullptr :
+ (const half *) (mask + nb33*(sequence % ne33));
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
+
+ const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
+ const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
+
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
+
+ if (KV_max) {
+ kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
+ }
+        constexpr bool is_fixup = false; // All iterations except (potentially) the last write their data to dst rather than to the fixup buffer.
+ if (kb0_start == 0) {
+ constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
+ (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
+ ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
+ } else {
+ constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
+ (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
+ ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
+ }
+
+ kbc += iter_k;
+ kbc -= kbc % iter_k;
+
+ kb0_start = 0;
+ kb0_stop = min(iter_k, kbc_stop - kbc);
+ }
+
+ if (kbc >= kbc_stop) {
+ return;
+ }
+
+ // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
+ const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
+ const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+ const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+ const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
+
+ const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
+
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
+ const half * mask_h = ncols2 == 1 && !mask ? nullptr :
+ (const half *) (mask + nb33*(sequence % ne33));
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
+
+ const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
+ const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
+
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
+
+ if (KV_max) {
+ kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
+ }
+
+    constexpr bool is_fixup = true; // The last, partial tile writes its data to the fixup buffer to avoid data races with other blocks.
+ constexpr bool needs_fixup = false;
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
+ (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
+ ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
+#else
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+ max_bias, m0, m1, n_head_log2, logit_softcap,
+ ne00, ne01, ne02, ne03,
+ nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13,
+ nb11, nb12, nb13,
+ nb21, nb22, nb23,
+ ne31, ne32, ne33,
+ nb31, nb32, nb33);
+ NO_DEVICE_CODE;
+#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
+}
+
+template <int DKQ, int DV, int ncols1, int ncols2>
+void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
+ const int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
+
+ constexpr int ncols = ncols1 * ncols2;
+
+ const int nthreads = ggml_cuda_fattn_mma_get_nthreads (DKQ, DV, ncols, cc);
+ const int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols, cc);
+ const int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols, cc);
+ const int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols, cc);
+ const int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols, cc);
+ const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc);
+ const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc);
+
+ const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
+ const int nwarps = nthreads / WARP_SIZE;
+
+ constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
+
+ const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_mask = ncols1 * (nbatch_fa/2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
+
+ const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
+
+ const size_t nbytes_shared_total = std::max(nbytes_shared_combine, Q_in_reg ?
+ std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) :
+ nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask);
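+    // This mirrors the device-side layout: with Q in registers the Q tile and the K/V + mask tiles can share the
+    // same buffer (hence the max), otherwise both are live at once (hence the sum); the combine step reuses the
+    // whole buffer, so the outer max covers it as well.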
+
+ float logit_softcap;
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+#if defined(GGML_USE_HIP)
+ using fattn_kernel_ptr_t = const void*;
+#else
+ using fattn_kernel_ptr_t = fattn_kernel_t;
+#endif // defined(GGML_USE_HIP)
+ fattn_kernel_t fattn_kernel;
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
+
+#if !defined(GGML_USE_MUSA)
+ static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+ if (!shared_memory_limit_raised[id]) {
+ CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+ shared_memory_limit_raised[id] = true;
+ }
+#endif // !defined(GGML_USE_MUSA)
+ } else {
+ constexpr bool use_logit_softcap = true;
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
+
+#if !defined(GGML_USE_MUSA)
+ static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
+ if (!shared_memory_limit_raised[id]) {
+ CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
+ shared_memory_limit_raised[id] = true;
+ }
+#endif // !defined(GGML_USE_MUSA)
+ }
+
+ launch_fattn<DV, ncols1, ncols2>
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
+}
+
+
+#define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2) \
+ template void ggml_cuda_flash_attn_ext_mma_f16_case \
+ <DKQ, DV, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(DKQ, DV, ncols) \
+ extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 1, 1); \
+ extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 2, 2); \
+ extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 4, 4); \
+ extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/ 8, 8); \
+ extern DECL_FATTN_MMA_F16_CASE(DKQ, DV, (ncols)/16, 16); \
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 8)
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 16)
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 32)
+
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 80, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 96, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
+
+// The number of viable configurations for DeepSeek is very limited:
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
+
+// For GLM 4.7 Flash
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);
diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cu b/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cu
new file mode 100644
index 0000000..3fcb09b
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cu
@@ -0,0 +1,49 @@
+#include "common.cuh"
+#include "fattn-tile.cuh"
+#include "fattn-wmma-f16.cuh"
+
+void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+ switch (K->ne[0]) {
+ case 40: {
+ GGML_ASSERT(V->ne[0] == K->ne[0]);
+ ggml_cuda_flash_attn_ext_tile_case< 40, 40>(ctx, dst);
+ } break;
+ case 64: {
+ GGML_ASSERT(V->ne[0] == K->ne[0]);
+ ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst);
+ } break;
+ case 72: {
+ GGML_ASSERT(V->ne[0] == K->ne[0]);
+ ggml_cuda_flash_attn_ext_tile_case< 72, 72>(ctx, dst);
+ } break;
+ case 80: {
+ GGML_ASSERT(V->ne[0] == K->ne[0]);
+ ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst);
+ } break;
+ case 96: {
+ GGML_ASSERT(V->ne[0] == K->ne[0]);
+ ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst);
+ } break;
+ case 112: {
+ GGML_ASSERT(V->ne[0] == K->ne[0]);
+ ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst);
+ } break;
+ case 128: {
+ GGML_ASSERT(V->ne[0] == K->ne[0]);
+ ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);
+ } break;
+ case 256: {
+ GGML_ASSERT(V->ne[0] == K->ne[0]);
+ ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
+ } break;
+ case 576: {
+ GGML_ASSERT(V->ne[0] == 512);
+ ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);
+ } break;
+ default: {
+ GGML_ABORT("Unsupported head size");
+ } break;
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cuh b/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cuh
new file mode 100644
index 0000000..b6db582
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/fattn-tile.cuh
@@ -0,0 +1,1256 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-wmma-f16.cuh"
+
+// nbatch_fa == number of KQ rows to process per iteration
+// nbatch_K == number of K columns to load in parallel for KQ calculation
+
+// TODO optimize kernel parameters for FP16 NVIDIA (P100)
+// TODO optimize kernel parameters for head sizes 40, 72, 80, 96, 112
+
+// The ROCm compiler cannot handle templating in __launch_bounds__.
+// As a workaround, define a macro that packs the kernel parameters into a single uint32_t:
+#define GGML_CUDA_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \
+ if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \
+ static_assert((nthreads) <= 512, "bad nthreads"); \
+ static_assert((occupancy) <= 8, "bad occupancy"); \
+ static_assert((nbatch_fa) <= 256, "bad nbatch_fa"); \
+ static_assert((nbatch_K) <= 256, "bad nbatch_K"); \
+ return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23); \
+ } \
+
+static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp16(const int DKQ, const int DV, const int ncols) {
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 64, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 64, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 64, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 64, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 64, 40)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 64, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 64, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 64, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 64, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 64, 72)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 64, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 64, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 64, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 64, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 64, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 64, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
+
+ return 0;
+}
+
+static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp32(const int DKQ, const int DV, const int ncols) {
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 128, 3, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 3, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 3, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 128, 3, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 3, 32, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 3, 64, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
+
+ return 0;
+}
+
+static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) {
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
+
+ return 0;
+}
+
+static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) {
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
+
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
+
+ return 0;
+}
+
+static __host__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
+ if (GGML_CUDA_CC_IS_AMD(cc)) {
+ if (GGML_CUDA_CC_IS_RDNA(cc)) {
+ return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols);
+ }
+ return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols);
+ }
+ if (fast_fp16_available(cc)) {
+ return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols);
+ }
+ return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols);
+}
+
+static constexpr __device__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) {
+#ifdef GGML_USE_HIP
+#ifdef RDNA
+ return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols);
+#else
+ return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols);
+#endif // RDNA
+#else
+#ifdef FAST_FP16_AVAILABLE
+ return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols);
+#else
+ return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols);
+#endif // FAST_FP16_AVAILABLE
+#endif // GGML_USE_HIP
+}
+
+static __host__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
+ return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1);
+}
+
+static constexpr __device__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) {
+ return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1);
+}
+
+static __host__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
+ return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1);
+}
+
+static constexpr __device__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) {
+ return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1);
+}
+
+static __host__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
+ return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1);
+}
+
+static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
+ return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1);
+}
+
+static __host__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) {
+ return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1);
+}
+
+static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) {
+ return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1);
+}
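+
+// For example, for DKQ == DV == 128 and ncols == 8 the NVIDIA FP16 table above yields
+// nthreads = 256, occupancy = 2, nbatch_fa = 64 and nbatch_K = 64, i.e. a packed value of
+// 256 | (2 << 10) | (64 << 14) | (64 << 23) == 0x20100900; the getters above merely shift
+// and mask this value to recover the individual fields.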
+
+// TODO: deduplicate with mma-f16
+template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
+static __device__ __forceinline__ void flash_attn_tile_load_tile(
+ const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
+ constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
+ constexpr int cpy_ne = cpy_nb / 4;
+
+ auto load = [&] __device__ (const int n) {
+ const int stride_j = warp_size >> n;
+
+ if (stride_j == 0) {
+ return;
+ }
+
+ const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j);
+ const int j0_stop = ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j);
+ const int stride_i = warp_size / stride_j;
+
+ if (j0_start == j0_stop) {
+ return;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
+ const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j);
+
+ if (i0 + nwarps*stride_i <= I || i < I) {
+#pragma unroll
+ for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
+ const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne;
+
+ const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}};
+ ggml_cuda_memcpy_1<cpy_nb>(
+ tile_KV + i*(J/2 + J_padding) + j,
+ !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
+ }
+ }
+ }
+ };
+ // 1: max 64*16=1024 bytes, 512 half
+ // 2: max 32*16=512 bytes, 256 half
+ // 3: max 16*16=256 bytes, 128 half
+ // 4: max 8*16=128 bytes, 64 half
+ // 5: max 4*16= 64 bytes, 32 half
+ // 6: max 2*16= 32 bytes, 16 half
+ // 7: max 1*16= 16 bytes, 8 half
+ static_assert(J % 8 == 0, "bad J");
+ static_assert((J/2) % cpy_ne == 0, "bad J");
+ ggml_cuda_unroll<7>{}(load);
+}
+
+template<int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
+static __device__ __forceinline__ void flash_attn_tile_load_tile(
+ const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) {
+ constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
+ constexpr int cpy_ne = cpy_nb / 4;
+
+ auto load = [&] __device__ (const int n) {
+ const int stride_j = warp_size >> n;
+
+ if (stride_j == 0) {
+ return;
+ }
+
+ const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j);
+ const int j0_stop = (J/cpy_ne) - (J/cpy_ne) % (1*stride_j);
+ const int stride_i = warp_size / stride_j;
+
+ if (j0_start == j0_stop) {
+ return;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
+ const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j);
+
+ if (i0 + nwarps*stride_i <= I || i < I) {
+#pragma unroll
+ for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
+ const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2);
+
+ const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}};
+ __align__(16) half2 tmp_h2[cpy_ne/2];
+ ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
+ tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
+
+ __align__(16) float2 tmp_f2[cpy_ne/2];
+#pragma unroll
+ for (int l = 0; l < cpy_ne/2; ++l) {
+ tmp_f2[l] = __half22float2(tmp_h2[l]);
+ }
+ ggml_cuda_memcpy_1<sizeof(tmp_f2)>(tile_KV + i*(J + J_padding) + 2*j, tmp_f2);
+ }
+ }
+ }
+ };
+ // 1: max 32*16=512 bytes, 128 float
+ // 2: max 16*16=256 bytes, 64 float
+ // 3: max 8*16=128 bytes, 32 float
+ // 4: max 4*16= 64 bytes, 16 float
+ // 5: max 2*16= 32 bytes, 8 float
+ static_assert(J % 8 == 0, "bad J");
+ static_assert(J % cpy_ne == 0, "bad J");
+ ggml_cuda_unroll<5>{}(load);
+}
+
+// Function that performs a single iteration of the KQ matrix multiplication:
+template <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int nbatch_fa, int nbatch_K,
+ bool use_logit_softcap, bool oob_check, typename T_vec_dot>
+static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
+ T_vec_dot * const Q_tmp,
+ const half2 * const __restrict__ K_h2,
+ T_vec_dot * const KV_tmp,
+ const int stride_K2,
+ const int k_VKQ_0,
+ const int k_VKQ_sup,
+ const int k_KQ_0,
+ float * KQ_acc) {
+ constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
+ constexpr int cpy_ne = cpy_nb / 4;
+
+ constexpr int ncols = ncols1*ncols2;
+ constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
+ constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column
+
+ flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
+ (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
+ __syncthreads();
+
+#ifdef FAST_FP16_AVAILABLE
+ static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
+#pragma unroll
+ for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {
+ __align__(16) half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+ __align__(16) half2 Q_k[cpw][cpy_ne];
+#else
+ static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K");
+#pragma unroll
+ for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {
+ __align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+ __align__(16) float Q_k[cpw][cpy_ne];
+#endif // FAST_FP16_AVAILABLE
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
+ const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;
+
+#ifdef FAST_FP16_AVAILABLE
+ ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]);
+#else
+ ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K + cpy_ne) + k_KQ_1]);
+#endif // FAST_FP16_AVAILABLE
+ }
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
+ const int jc = jc0 + (threadIdx.y / np)*cpw;
+
+#ifdef FAST_FP16_AVAILABLE
+ ggml_cuda_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]);
+#else
+ ggml_cuda_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc* DKQ + k_KQ_0 + k_KQ_1]);
+#endif // FAST_FP16_AVAILABLE
+ }
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
+#pragma unroll
+ for (int k = 0; k < cpy_ne; ++k) {
+ ggml_cuda_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]);
+ }
+ }
+ }
+ }
+
+ if (k_KQ_0 + nbatch_K < DKQ) {
+ __syncthreads(); // Sync not needed on last iteration.
+ }
+}
+
+// Function that performs a single iteration of the main loop over up to nbatch_fa tokens.
+template <int warp_size, int nwarps, int ncols1, int ncols2, int DKQ, int DV, int nbatch_fa, int nbatch_K,
+ bool use_logit_softcap, bool oob_check, typename T_vec_dot, typename T_KQ, typename T_acc>
+static __device__ __forceinline__ void flash_attn_tile_iter(
+ T_vec_dot * const Q_tmp,
+ const half2 * const __restrict__ K_h2,
+ const half2 * const __restrict__ V_h2,
+ const half * const __restrict__ mask,
+ const uint3 ne01,
+ const float logit_softcap,
+ const float slope,
+ T_KQ * const KQ,
+ T_vec_dot * const KV_tmp,
+ const int stride_K2,
+ const int stride_V2,
+ const int stride_mask,
+ float * const KQ_max,
+ float * const KQ_sum,
+ T_acc * const VKQ,
+ const int k_VKQ_0,
+ const int k_VKQ_max,
+ const int col_Q_0) {
+ constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
+ constexpr int cpy_ne = cpy_nb / 4;
+
+ constexpr int ncols = ncols1*ncols2;
+ constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
+ constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column
+
+ constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size.
+
+ // KQ_cs == KQ chunk size, number of KQ values in j direction to store as one contiguous chunk in memory.
+ // KQ is originally 2D but uses a Z-shaped 3D memory pattern like KQ[ncols/KQ_cs][nbatch_fa][KQ_cs].
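+ // For example, the KQ value for Q column jc and KV row i is stored at
+ // KQ[(jc/KQ_cs)*(nbatch_fa*KQ_cs) + i*KQ_cs + jc % KQ_cs], so KQ_cs consecutive j values
+ // for the same i form one contiguous chunk.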
+#ifdef FAST_FP16_AVAILABLE
+ constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
+#else
+ constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
+#endif // FAST_FP16_AVAILABLE
+ static_assert(cpw % KQ_cs == 0, "bad KQ_cs");
+ const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data
+
+ float KQ_max_new[cpw];
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
+ KQ_max_new[jc0] = KQ_max[jc0];
+ }
+
+ float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication.
+
+ // KQ = K @ Q matrix multiplication:
+ constexpr int nbatch_K_last = DKQ % nbatch_K;
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) {
+ flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>(
+ Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
+ }
+ if (nbatch_K_last > 0) {
+ constexpr int k_KQ_0 = DKQ - nbatch_K_last;
+ flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K_last, use_logit_softcap, oob_check>(
+ Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
+ }
+
+ // Apply logit softcap + mask, update KQ_max:
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
+ const int j = fastmodulo(col_Q_0 + (jc0 + (threadIdx.y / np)*cpw)/ncols2, ne01);
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
+ const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;
+
+#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
+ // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.
+ // Therefore, scale down the Q values and apply the inverse scale to the FP32 KQ values afterwards.
+ KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0] *= 4.0f;
+#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
+
+ if (use_logit_softcap) {
+ KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
+ }
+
+ if (!oob_check || i_KQ < k_VKQ_sup) {
+ KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
+ slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
+
+ KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] + FATTN_KQ_MAX_OFFSET);
+ }
+ }
+
+ KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
+ }
+
+ if constexpr (np == 1) {
+ __syncthreads();
+ } else {
+ static_assert(cpw == 1, "bad cpw");
+ __shared__ float KQ_max_new_shared[nwarps];
+ if (threadIdx.x == 0) {
+ KQ_max_new_shared[threadIdx.y] = KQ_max_new[0];
+ }
+ __syncthreads();
+ KQ_max_new[0] = KQ_max_new_shared[(threadIdx.y & ~(np-1)) + threadIdx.x % np];
+ KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
+ }
+
+ // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
+#ifdef FAST_FP16_AVAILABLE
+ __align__(16) half tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+#else
+ __align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+#endif // FAST_FP16_AVAILABLE
+
+#pragma unroll
+ for (int jc1 = 0; jc1 < KQ_cs; ++jc1) {
+ const int jc = jc0 + jc1;
+
+ const float KQ_max_scale = expf(KQ_max[jc] - KQ_max_new[jc]);
+ KQ_max[jc] = KQ_max_new[jc];
+
+ float KQ_sum_add = 0.0f;
+#pragma unroll
+ for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
+ const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast<uint32_t>(k_VKQ_sup) ?
+ expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
+ KQ_sum_add += val;
+ tmp[i0/(np*warp_size)][jc1] = val;
+ }
+ KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
+
+#ifdef FAST_FP16_AVAILABLE
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+ VKQ[jc*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;
+ }
+#else
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+ VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale;
+ VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale;
+ }
+#endif // FAST_FP16_AVAILABLE
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
+ const int i = i0 + (threadIdx.y % np)*warp_size + threadIdx.x;
+
+ ggml_cuda_memcpy_1<sizeof(tmp[0])>(
+ KQ + (jc0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs))*(nbatch_fa*KQ_cs) + i*KQ_cs,
+ tmp[i0/(np*warp_size)]);
+ }
+ }
+
+ // VKQ = V @ KQ matrix multiplication:
+ static_assert(DV <= DKQ, "bad DV");
+ static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K");
+ constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K.
+ static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V");
+ static_assert(nbatch_V % np == 0, "bad nbatch_V");
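+ // For example, with DV = 128, nbatch_K = 64 and nbatch_fa = 64 (the NVIDIA FP16 config for ncols == 8)
+ // this gives nbatch_V = 32, so the nbatch_fa KV tokens are covered by two V loads per iteration.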
+#pragma unroll
+ for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
+ flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
+ (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
+ __syncthreads();
+
+#ifdef FAST_FP16_AVAILABLE
+#pragma unroll
+ for (int k1 = 0; k1 < nbatch_V; k1 += np) {
+ __align__(16) half2 V_k[(DVp/2)/warp_size];
+ __align__(16) half2 KQ_k[cpw];
+
+ constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+ ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[(k1 + threadIdx.y % np)*(DV/2) + i0 + threadIdx.x*cpy_ne_D]);
+ }
+#pragma unroll
+ for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
+ const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);
+
+ __align__(16) half tmp[KQ_cs];
+ ggml_cuda_memcpy_1<KQ_cs*sizeof(half)>(
+ &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
+#pragma unroll
+ for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) {
+ KQ_k[jc_VKQ_0+jc_VKQ_1] = __half2half2(tmp[jc_VKQ_1]);
+ }
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+#pragma unroll
+ for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
+ VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size] += V_k[i0/warp_size]*KQ_k[jc_VKQ_0];
+ }
+ }
+ }
+#else
+#pragma unroll
+ for (int k1 = 0; k1 < nbatch_V; k1 += np) {
+ __align__(16) float2 V_k[(DVp/2)/warp_size];
+ __align__(16) float KQ_k[cpw];
+
+ constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+ for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+ ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + threadIdx.y % np)*DV + i0 + threadIdx.x*cpy_ne_D]);
+ }
+#pragma unroll
+ for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
+ const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);
+
+ ggml_cuda_memcpy_1<KQ_cs*sizeof(float)>(
+ &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+#pragma unroll
+ for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
+ VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[jc_VKQ_0];
+ VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[jc_VKQ_0];
+ }
+ }
+ }
+#endif // FAST_FP16_AVAILABLE
+
+ __syncthreads();
+ }
+}
+
+template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap> // D == head size
+__launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2))
+static __global__ void flash_attn_tile(
+ const char * __restrict__ Q,
+ const char * __restrict__ K,
+ const char * __restrict__ V,
+ const char * __restrict__ mask,
+ const char * __restrict__ sinks,
+ const int * __restrict__ KV_max,
+ float * __restrict__ dst,
+ float2 * __restrict__ dst_meta,
+ const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
+ const float logit_softcap,
+ const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
+#ifdef FLASH_ATTN_AVAILABLE
+
+ // Skip unused kernel variants for faster compilation:
+
+ if (
+#ifdef GGML_USE_WMMA_FATTN
+ (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
+#endif // GGML_USE_WMMA_FATTN
+ (use_logit_softcap && !(DV == 128 || DV == 256))
+ ) {
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+ max_bias, m0, m1, n_head_log2, logit_softcap,
+ ne00, ne01, ne02, ne03,
+ nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13,
+ nb11, nb12, nb13,
+ nb21, nb22, nb23,
+ ne31, ne32, ne33,
+ nb31, nb32, nb33);
+ NO_DEVICE_CODE;
+ return;
+ }
+
+ static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined");
+
+ constexpr int ncols = ncols1*ncols2;
+ constexpr int warp_size = 32;
+ constexpr int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size;
+ constexpr int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2);
+ constexpr int nbatch_K = ggml_cuda_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2);
+
+ // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+ const int col_Q_0 = blockIdx.x * ncols1; // Index of the first Q column for this CUDA block to work on.
+
+ const int sequence = blockIdx.z / (ne02/ncols2);
+ const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == (blockIdx.z % (ne02/ncols2)) * ncols2
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+ const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0);
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
+ const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape
+
+ const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33)) : nullptr;
+
+ const int stride_K2 = nb11 / sizeof(half2);
+ const int stride_V2 = nb21 / sizeof(half2);
+ const int stride_mask = nb31 / sizeof(half);
+
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
+
+ constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
+ constexpr int cpy_ne = cpy_nb / 4;
+
+ constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp.
+ constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column.
+ static_assert(cpw == 1 || np == 1, "bad cpw / np");
+ static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0");
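+ // For example, 256 threads (8 warps) with ncols == 32 give cpw = 4 and np = 1,
+ // while 128 threads (4 warps) with ncols == 2 give cpw = 1 and np = 2.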
+
+ constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size.
+ constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size.
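+ // For example, with warp_size == 32, DV = 40 is padded to DVp = 64 and DV = 96 to DVp = 128.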
+
+ // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel.
+ // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11.
+ // KV_tmp is padded to avoid shared memory bank conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV).
+ // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications.
+ // VKQ == Accumulators in registers for the final VKQ result.
+#ifdef FAST_FP16_AVAILABLE
+ __shared__ half2 Q_tmp[ncols * DKQ/2];
+ __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV];
+ __shared__ half KQ[ncols * nbatch_fa];
+ __align__(16) half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+#else
+ __shared__ float Q_tmp[ncols * DKQ];
+ __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV];
+ __shared__ float KQ[ncols * nbatch_fa];
+ __align__(16) float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+#endif // FAST_FP16_AVAILABLE
+
+ float KQ_max[cpw];
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
+ }
+ float KQ_sum[cpw] = {0.0f};
+
+ // Load Q data, convert to FP16 if fast:
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
+ const int jc = jc0 + (threadIdx.y / np)*cpw;
+
+ const int j = jc / ncols2;
+ const int c = jc % ncols2;
+
+ constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size;
+
+#pragma unroll
+ for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
+ if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
+ __align__(16) float tmp_f[cpy_ne_D] = {0.0f};
+ ggml_cuda_memcpy_1<sizeof(tmp_f)>
+ (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float))
+ + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
+
+#pragma unroll
+ for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+ tmp_f[i1] *= scale;
+ }
+
+#ifdef FAST_FP16_AVAILABLE
+ __align__(16) half2 tmp_h2[cpy_ne_D/2];
+#pragma unroll
+ for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
+ tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
+#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
+ // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.
+ // Therefore, scale down the Q values and apply the inverse scale to the FP32 KQ values afterwards.
+ tmp_h2[i1/2] *= make_half2(0.25f, 0.25f);
+#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
+ }
+ ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
+ &Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)],
+ tmp_h2);
+#else
+ ggml_cuda_memcpy_1<sizeof(tmp_f)>(
+ &Q_tmp[jc* DKQ + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x* cpy_ne_D],
+ tmp_f);
+#endif // FAST_FP16_AVAILABLE
+ }
+ }
+ }
+
+ __syncthreads();
+
+ // Main loop over KV cache:
+ const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+ if (ncols2 == 1) {
+ // Branch with out-of-bounds checks.
+ int k_VKQ_0 = blockIdx.y*nbatch_fa;
+ while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {
+ constexpr bool oob_check = false;
+ flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
+ (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
+ stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
+ k_VKQ_0 += gridDim.y*nbatch_fa;
+ }
+ if (k_VKQ_0 < k_VKQ_max) {
+ constexpr bool oob_check = true;
+ flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
+ (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
+ stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
+ }
+ } else {
+ // Branch without out-of-bounds checks.
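+ // Out-of-bounds checks are presumably unnecessary here: the ncols2 > 1 path is only taken
+ // when the host code has verified K->ne[1] % FATTN_KQ_STRIDE == 0 (see launch_fattn_tile_switch_ncols2 below).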
+ for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) {
+ constexpr bool oob_check = false;
+ flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>
+ (Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp,
+ stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0);
+ }
+ }
+
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
+ KQ_sum[jc0] = warp_reduce_sum<warp_size>(KQ_sum[jc0]);
+ }
+
+ if constexpr (np > 1) {
+ static_assert(cpw == 1, "bad cpw");
+ static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small");
+
+#ifdef FAST_FP16_AVAILABLE
+ half2 * VKQ_combine = (half2 *) KV_tmp;
+#else
+ float * VKQ_combine = (float *) KV_tmp;
+#endif // FAST_FP16_AVAILABLE
+ float * KQ_sum_combine = (float *) Q_tmp;
+
+ if (threadIdx.y % np != 0) {
+#ifdef FAST_FP16_AVAILABLE
+ constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+ ggml_cuda_memcpy_1<cpy_ne_D*4>(&VKQ_combine[threadIdx.y*(DVp/2) + i0 + threadIdx.x*cpy_ne_D], &VKQ[i0/warp_size]);
+ }
+#else
+ constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+ for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+ ggml_cuda_memcpy_1<cpy_ne_D*4>(
+ &VKQ_combine[threadIdx.y*DVp + i0 + threadIdx.x*cpy_ne_D], ((const float *) VKQ) + i0/warp_size);
+ }
+#endif // FAST_FP16_AVAILABLE
+
+ if (threadIdx.x == 0) {
+ KQ_sum_combine[threadIdx.y] = KQ_sum[0];
+ }
+
+ return;
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int ip = 1; ip < np; ++ip) {
+#ifdef FAST_FP16_AVAILABLE
+ constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+ __align__(16) half2 tmp[cpy_ne_D];
+ ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]);
+#pragma unroll
+ for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+ VKQ[i0/warp_size + i1] += tmp[i1];
+ }
+ }
+#else
+ constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+ for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+ __align__(16) float tmp[cpy_ne_D];
+ ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]);
+#pragma unroll
+ for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+ ((float *)VKQ)[i0/warp_size + i1] += tmp[i1];
+ }
+ }
+#endif // FAST_FP16_AVAILABLE
+
+ KQ_sum[0] += KQ_sum_combine[threadIdx.y + ip];
+ }
+ }
+
+ // Attention sink: adjust KQ max and sum only for the first of all parallel blocks:
+ if (sinks && blockIdx.y == 0) {
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
+ const int jc = jc0 + (threadIdx.y/np)*cpw;
+ const float sink = ((const float *) sinks)[head0 + jc % ncols2];
+
+ float KQ_max_new_j = fmaxf(KQ_max[jc0], sink);
+ const float KQ_max_scale = expf(KQ_max[jc0] - KQ_max_new_j);
+ KQ_max[jc0] = KQ_max_new_j;
+
+ const float val = expf(sink - KQ_max[jc0]);
+ KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val;
+
+#ifdef FAST_FP16_AVAILABLE
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+ VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;
+ }
+#else
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+ VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale;
+ VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale;
+ }
+#endif // FAST_FP16_AVAILABLE
+ }
+ }
+
+ // Write back results:
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
+ const int jc = jc0 + (threadIdx.y/np)*cpw;
+
+ const int j = jc / ncols2;
+ const int c = jc % ncols2;
+
+ if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z)) {
+ return;
+ }
+
+ const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
+
+ const int j_dst_unrolled = ((sequence*int(ne01.z) + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
+
+#ifdef FAST_FP16_AVAILABLE
+ constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+ __align__(16) float2 tmp[cpy_ne_D];
+#pragma unroll
+ for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+ tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
+ tmp[i1].x *= scale;
+ tmp[i1].y *= scale;
+ }
+ if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
+ ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
+ }
+ }
+#else
+ constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+ for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+ if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
+#pragma unroll
+ for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
+ VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
+ VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
+ }
+ ggml_cuda_memcpy_1<cpy_ne_D*4>(
+ &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
+ &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);
+ }
+ }
+#endif // FAST_FP16_AVAILABLE
+
+ if (gridDim.y != 1 && threadIdx.x == 0) {
+ dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]);
+ }
+ }
+#else
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+ max_bias, m0, m1, n_head_log2, logit_softcap,
+ ne00, ne01, ne02, ne03,
+ nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13,
+ nb11, nb12, nb13,
+ nb21, nb22, nb23,
+ ne31, ne32, ne33,
+ nb31, nb32, nb33);
+ NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
+}
+
+template <int DKQ, int DV, int ncols2, bool use_logit_softcap>
+static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * Q = dst->src[0];
+
+ const int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
+ const int warp_size = 32;
+
+ constexpr size_t nbytes_shared = 0;
+
+#ifdef GGML_USE_HIP
+ if constexpr (DV <= 128) {
+ if (Q->ne[1] > 32/ncols2) {
+ constexpr int cols_per_block = 64;
+ const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+ const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+ fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
+ launch_fattn<DV, cols_per_block/ncols2, ncols2>
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
+ return;
+ }
+ }
+#endif // GGML_USE_HIP
+
+#ifndef GGML_USE_HIP
+ if constexpr (DV <= 256)
+#endif // GGML_USE_HIP
+ {
+ if (Q->ne[1] > 16/ncols2) {
+ constexpr int cols_per_block = 32;
+ const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+ const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+ fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
+ launch_fattn<DV, cols_per_block/ncols2, ncols2>
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
+ return;
+ }
+ }
+
+ if (Q->ne[1] > 8/ncols2) {
+ constexpr int cols_per_block = 16;
+ const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+ const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+ fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
+ launch_fattn<DV, cols_per_block/ncols2, ncols2>
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
+ return;
+ }
+
+ if constexpr (ncols2 <= 8) {
+ if (Q->ne[1] > 4/ncols2) {
+ constexpr int cols_per_block = 8;
+ const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+ const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+ fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
+ launch_fattn<DV, cols_per_block/ncols2, ncols2>
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
+ return;
+ }
+ }
+
+ if constexpr (ncols2 <= 4) {
+ if (Q->ne[1] > 2/ncols2) {
+ constexpr int cols_per_block = 4;
+ const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+ const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+ fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
+ launch_fattn<DV, cols_per_block/ncols2, ncols2>
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
+ return;
+ }
+ }
+
+ if constexpr (ncols2 <= 2) {
+ constexpr int cols_per_block = 2;
+ const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+ const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+ fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>;
+ launch_fattn<DV, cols_per_block/ncols2, ncols2>
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size);
+ return;
+ }
+
+ GGML_ABORT("fatal error");
+}
+
+template <int DKQ, int DV, bool use_logit_softcap>
+static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * mask = dst->src[3];
+
+ float max_bias = 0.0f;
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
+
+ const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc);
+ const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX;
+ const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
+
+ if constexpr (DV == 512) {
+ if (use_gqa_opt && gqa_ratio % 16 == 0) {
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
+ return;
+ }
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
+ return;
+ }
+ }
+
+ if constexpr (DV <= 256) {
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
+ return;
+ }
+
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
+ return;
+ }
+
+ if (use_gqa_opt && gqa_ratio % 2 == 0) {
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
+ return;
+ }
+
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
+ return;
+ }
+ GGML_ABORT("fatal error");
+}
+
+template <int DKQ, int DV>
+void ggml_cuda_flash_attn_ext_tile_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
+
+ float logit_softcap;
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);
+ }
+}
+
+void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+#define DECL_FATTN_TILE_CASE(DKQ, DV) \
+ template void ggml_cuda_flash_attn_ext_tile_case \
+ <DKQ, DV>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_TILE_CASE( 40, 40);
+extern DECL_FATTN_TILE_CASE( 64, 64);
+extern DECL_FATTN_TILE_CASE( 72, 72);
+extern DECL_FATTN_TILE_CASE( 80, 80);
+extern DECL_FATTN_TILE_CASE( 96, 96);
+extern DECL_FATTN_TILE_CASE(112, 112);
+extern DECL_FATTN_TILE_CASE(128, 128);
+extern DECL_FATTN_TILE_CASE(256, 256);
+extern DECL_FATTN_TILE_CASE(576, 512);
diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh b/llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh
new file mode 100644
index 0000000..3f4a78c
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/fattn-vec.cuh
@@ -0,0 +1,586 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+
+static int ggml_cuda_fattn_vec_get_nthreads_host(const int cc) {
+ return 128;
+ GGML_UNUSED(cc);
+}
+
+static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() {
+ return 128;
+}
+
+// Currently LLVM with the amdgcn target does not support unrolling loops
+// that contain a break that cannot be resolved at compile time.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
+__launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1)
+static __global__ void flash_attn_ext_vec(
+ const char * __restrict__ Q,
+ const char * __restrict__ K,
+ const char * __restrict__ V,
+ const char * __restrict__ mask,
+ const char * __restrict__ sinks,
+ const int * __restrict__ KV_max,
+ float * __restrict__ dst,
+ float2 * __restrict__ dst_meta,
+ const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
+ const float logit_softcap,
+ const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
+#ifdef FLASH_ATTN_AVAILABLE
+
+ // Skip unused kernel variants for faster compilation:
+ if (use_logit_softcap && !(D == 128 || D == 256)) {
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+ max_bias, m0, m1, n_head_log2, logit_softcap,
+ ne00, ne01, ne02, ne03,
+ nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13,
+ nb11, nb12, nb13,
+ nb21, nb22, nb23,
+ ne31, ne32, ne33,
+ nb31, nb32, nb33);
+ NO_DEVICE_CODE;
+ return;
+ }
+
+ // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+ constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
+ constexpr int cpy_ne = cpy_nb / 4;
+
+#ifdef GGML_USE_HIP
+#ifdef RDNA
+ constexpr int nthreads_KQ_q = 2;
+#else
+ constexpr int nthreads_KQ_q = 4;
+#endif // RDNA
+ constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32);
+#else
+ constexpr int nthreads_KQ_q = (D/4 < 32 ? D/4 : 32);
+ constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32);
+#endif // GGML_USE_HIP
+
+ constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device();
+ constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
+ constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;
+
+ static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K");
+ static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
+
+ constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
+ constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V;
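+ // For example, assuming cpy_nb == 16 (the actual value depends on the architecture), F16 K and V
+ // give nthreads_KQ = nthreads_V = 8, V_cols_per_iter = 4 and V_rows_per_thread = 8.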
+
+ constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
+ constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
+#ifdef V_DOT2_F32_F16_AVAILABLE
+ constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
+#else
+ constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
+#endif // V_DOT2_F32_F16_AVAILABLE
+
+ const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
+
+ const int sequence = blockIdx.z / ne02;
+ const int head = blockIdx.z - sequence*ne02;
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+ Q += nb03*sequence + nb02* head + nb01*ic0;
+ K += nb13*sequence + nb12*(head / gqa_ratio);
+ V += nb23*sequence + nb22*(head / gqa_ratio);
+
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+
+ const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
+
+ static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
+ constexpr int nwarps = nthreads / WARP_SIZE;
+ const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+ __builtin_assume(tid < nthreads);
+
+ constexpr int ne_KQ = ncols*D;
+ constexpr int ne_combine = nwarps*V_cols_per_iter*D;
+#ifdef V_DOT2_F32_F16_AVAILABLE
+ half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
+ __shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
+#else
+ float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
+ __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
+#endif // V_DOT2_F32_F16_AVAILABLE
+
+ float KQ_max[ncols];
+ float KQ_sum[ncols];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ KQ_max[j] = -FLT_MAX/2.0f;
+ KQ_sum[j] = 0.0f;
+ }
+
+ // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
+#ifdef V_DOT2_F32_F16_AVAILABLE
+ half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
+#else
+ __align__(16) float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
+#endif // V_DOT2_F32_F16_AVAILABLE
+ int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
+ float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
+ if constexpr (Q_q8_1) {
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (j0 + nwarps > ncols && j >= ncols) {
+ break;
+ }
+
+ // Reuse KQ as temporary storage for converting Q to q8_1:
+ int * tmp_q_i32 = (int *) &KQ[j*D];
+ float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
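+            // tmp_q_i32 receives the packed q8_1 quants and tmp_q_ds the per-block scale/sum
+            // metadata from quantize_q8_1_to_shared; both are copied into registers further down.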
+
+ // Set memory to zero if out of bounds:
+ if (ncols > 1 && ic0 + j >= int(ne01.z)) {
+#pragma unroll
+ for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
+ tmp_q_i32[i] = 0;
+ }
+ }
+ if (threadIdx.x < D/QK8_1) {
+ tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
+ }
+ } else {
+ const float * Q_f = (const float *) (Q + j*nb01);
+ constexpr int nthreads_quantize = D/sizeof(int) < WARP_SIZE ? D/sizeof(int) : WARP_SIZE;
+#pragma unroll
+ for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) {
+ quantize_q8_1_to_shared<float2, nthreads_quantize>
+ (Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1);
+ }
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ int * tmp_q_i32 = (int *) &KQ[j*D];
+ float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
+
+#pragma unroll
+ for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) {
+ const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ);
+
+ Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i];
+ Q_ds[j][i0/nthreads_KQ] = tmp_q_ds[i/QI8_1];
+ }
+ }
+
+ __syncthreads();
+ } else {
+#ifdef V_DOT2_F32_F16_AVAILABLE
+ const half2 scale_h2 = make_half2(scale, scale);
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ const float2 * Q_j = (const float2 *) (Q + j*nb01);
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
+ const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
+
+ __align__(16) float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
+ if (ncols == 1 || ic0 + j < int(ne01.z)) {
+ ggml_cuda_memcpy_1<cpy_nb>(tmp, &Q_j[i]);
+ ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
+ }
+#pragma unroll
+ for (int i1 = 0; i1 < cpy_ne; ++i1) {
+ Q_reg[j][i0/nthreads_KQ + i1] = make_half2(tmp[i1].x, tmp[i1].y);
+ }
+ }
+#pragma unroll
+ for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
+ Q_reg[j][k] *= scale_h2;
+ }
+ }
+#else
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ const float2 * Q_j = (const float2 *) (Q + j*nb01);
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
+ const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
+ if (ncols == 1 || ic0 + j < int(ne01.z)) {
+ ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]);
+ ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
+ }
+ }
+#pragma unroll
+ for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
+ Q_reg[j][k].x *= scale;
+ Q_reg[j][k].y *= scale;
+ }
+ }
+#endif // V_DOT2_F32_F16_AVAILABLE
+ }
+
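+    // The KV sequence is optionally split across gridDim.y blocks: each block starts at its own
+    // offset of blockIdx.y*nthreads and strides by gridDim.y*nthreads. For gridDim.y > 1 the
+    // partial results stay unnormalized and (max, sum) metadata is written to dst_meta so they
+    // can be combined outside this kernel.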
+ const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+ K += blockIdx.y*nthreads * nb11;
+ V += blockIdx.y*nthreads * nb21;
+ maskh += blockIdx.y*nthreads;
+ for (int k_VKQ_0 = blockIdx.y*nthreads; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nthreads,
+ // Increment pointers after each loop:
+ K += gridDim.y*nthreads*nb11, V += gridDim.y*nthreads*nb21, maskh += gridDim.y*nthreads) {
+
+ // Calculate KQ tile and keep track of new maximum KQ values:
+ float KQ_reg[ncols]; // KQ in registers.
+
+ float KQ_max_new[ncols];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ KQ_max_new[j] = KQ_max[j];
+ }
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) {
+ const int i_KQ = threadIdx.y*WARP_SIZE + (nthreads_KQ == WARP_SIZE ? 0 : (threadIdx.x & ~(nthreads_KQ-1))) + i_KQ_0;
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
+ sum = warp_reduce_sum<nthreads_KQ>(sum);
+
+ if (use_logit_softcap) {
+ sum = logit_softcap*tanhf(sum);
+ }
+
+ if (mask && (ncols == 1 || ic0 + j < int(ne01.z))) {
+ sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
+ }
+
+ KQ_max_new[j] = fmaxf(KQ_max_new[j], sum + FATTN_KQ_MAX_OFFSET);
+
+ if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
+ KQ_reg[j] = sum;
+ }
+ }
+ }
+
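+        // Online softmax update: when the running maximum increases from m_old to m_new, the
+        // previous accumulators are rescaled by exp(m_old - m_new) so that
+        // sum_new = sum_old*exp(m_old - m_new) + exp(KQ - m_new) remains correct in a single pass.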
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+ for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) {
+ KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE));
+ }
+ const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]);
+ KQ_max[j] = KQ_max_new[j];
+
+ KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]);
+ KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
+ KQ[j*nthreads + tid] = KQ_reg[j];
+
+#ifdef V_DOT2_F32_F16_AVAILABLE
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+ VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
+ }
+#else
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+ VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
+ VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
+ }
+#endif // V_DOT2_F32_F16_AVAILABLE
+ }
+
+#ifndef GGML_USE_HIP
+ __syncwarp();
+#endif // GGML_USE_HIP
+
+#pragma unroll
+ for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
+ const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
+
+#ifdef V_DOT2_F32_F16_AVAILABLE
+ half2 KQ_k[ncols];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ KQ_k[j] = __half2half2(KQ[j*nthreads + k]);
+ }
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+ half2 tmp[V_rows_per_thread/2];
+ dequantize_V(V + k*nb21, tmp,
+ 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
+#pragma unroll
+ for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j];
+ }
+ }
+ }
+#else
+ float KQ_k[ncols];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ KQ_k[j] = KQ[j*nthreads + k];
+ }
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+ float2 tmp[V_rows_per_thread/2];
+ dequantize_V(V + k*nb21, tmp,
+ 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
+#pragma unroll
+ for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j];
+ VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j];
+ }
+ }
+ }
+#endif // V_DOT2_F32_F16_AVAILABLE
+ }
+ }
+
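+    // Attention sinks: an extra per-head logit that enters the running maximum and the softmax
+    // denominator but has no corresponding V row, so the VKQ accumulators are only rescaled here.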
+ if (sinks && blockIdx.y == 0) {
+ const float sink = ((const float *) sinks)[head];
+
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (j0 + nwarps > ncols && j >= ncols) {
+ break;
+ }
+
+ const float kqmax_new_j = fmaxf(sink, KQ_max[j]);
+ const float KQ_max_scale = expf(KQ_max[j] - kqmax_new_j);
+ KQ_max[j] = kqmax_new_j;
+
+ KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
+
+#ifdef V_DOT2_F32_F16_AVAILABLE
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+ VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
+ }
+#else
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+ VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
+ VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
+ }
+#endif // V_DOT2_F32_F16_AVAILABLE
+ }
+ }
+
+ __shared__ float KQ_max_shared[ncols][WARP_SIZE];
+ __shared__ float KQ_sum_shared[ncols][WARP_SIZE];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ if (threadIdx.y == 0) {
+ KQ_max_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
+ KQ_sum_shared[j][threadIdx.x] = 0.0f;
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ if (threadIdx.x == 0) {
+ KQ_max_shared[j][threadIdx.y] = KQ_max[j];
+ }
+ }
+ __syncthreads();
+
+#pragma unroll
+ for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+ if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z)) {
+ break;
+ }
+
+ float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x];
+ kqmax_new = warp_reduce_max(kqmax_new);
+ const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
+ KQ_max[j_VKQ] = kqmax_new;
+
+#ifdef V_DOT2_F32_F16_AVAILABLE
+ half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
+ + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
+
+ const half2 kqmax_scale_h2 = make_half2(kqmax_scale, kqmax_scale);
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+ VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2;
+ }
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+ const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2);
+
+ ggml_cuda_memcpy_1<V_rows_per_thread*sizeof(half)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
+ }
+#else
+ float2 * VKQ_tmp = (float2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
+ + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
+
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+ VKQ[j_VKQ][i_VKQ_0/nthreads_V].x *= kqmax_scale;
+ VKQ[j_VKQ][i_VKQ_0/nthreads_V].y *= kqmax_scale;
+ }
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+ const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2);
+
+ ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
+ ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
+ }
+#endif // V_DOT2_F32_F16_AVAILABLE
+
+ KQ_sum[j_VKQ] *= kqmax_scale;
+ KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
+ if (threadIdx.x == 0) {
+ KQ_sum_shared[j_VKQ][threadIdx.y] = KQ_sum[j_VKQ];
+ }
+
+ __syncthreads();
+
+ if (nthreads <= D || tid < D) {
+ KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ][threadIdx.x];
+ KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
+
+#pragma unroll
+ for (int i0 = 0; i0 < D; i0 += nthreads) {
+ float dst_val = 0;
+#pragma unroll
+ for (int w = 0; w < nwarps; ++w) {
+#pragma unroll
+ for (int v = 0; v < V_cols_per_iter; ++v) {
+ dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]);
+ }
+ }
+ if (gridDim.y == 1) {
+ dst_val /= KQ_sum[j_VKQ];
+ }
+ dst[(((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
+ }
+ }
+
+ if (j_VKQ < ncols-1) {
+ __syncthreads();
+ }
+
+ }
+
+ if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z))) {
+ dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
+ }
+#else
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+ max_bias, m0, m1, n_head_log2, logit_softcap,
+ ne00, ne01, ne02, ne03,
+ nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13,
+ nb11, nb12, nb13,
+ nb21, nb22, nb23,
+ ne31, ne32, ne33,
+ nb31, nb32, nb33);
+ NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
+}
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
+
+template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
+void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+ const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
+ const int nwarps = nthreads / WARP_SIZE;
+ fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
+ const bool need_f16_K = type_K == GGML_TYPE_F16;
+ const bool need_f16_V = type_V == GGML_TYPE_F16;
+ constexpr size_t nbytes_shared = 0;
+ launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
+}
+
+template <int D, ggml_type type_K, ggml_type type_V>
+void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+
+ float logit_softcap;
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+ if (Q->ne[1] == 1) {
+ constexpr int cols_per_block = 1;
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
+ }
+ return;
+ }
+
+ constexpr int cols_per_block = 2;
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
+ }
+}
+
+#define DECL_FATTN_VEC_CASE(D, type_K, type_V) \
+ template void ggml_cuda_flash_attn_ext_vec_case \
+ <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \
+ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \
+ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \
+ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \
+ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
+ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
+ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
+
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
+
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
+
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu
new file mode 100644
index 0000000..8694fd0
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -0,0 +1,675 @@
+// Old and deprecated WMMA FlashAttention implementation.
+// It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing.
+// Long-term the WMMA code should be replaced with a dedicated Volta implementation.
+
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-wmma-f16.cuh"
+
+#ifdef GGML_USE_WMMA_FATTN
+#if !defined(GGML_USE_HIP)
+#include <mma.h>
+#if defined(GGML_USE_MUSA)
+namespace wmma = mtmusa::wmma;
+#else // GGML_USE_MUSA
+namespace wmma = nvcuda::wmma;
+#endif // GGML_USE_MUSA
+#elif defined(GGML_USE_HIP)
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
+#endif // !defined(GGML_USE_HIP)
+#endif // GGML_USE_WMMA_FATTN
+
+// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
+template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
+__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
+static __global__ void flash_attn_ext_f16(
+ const char * __restrict__ Q,
+ const char * __restrict__ K,
+ const char * __restrict__ V,
+ const char * __restrict__ mask,
+ const char * __restrict__ sinks,
+ const int * __restrict__ KV_max,
+ float * __restrict__ dst,
+ float2 * __restrict__ dst_meta,
+ const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
+ const float logit_softcap,
+ const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
+ const int32_t nb01, const int32_t nb02, const int32_t nb03,
+ const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+ const int32_t nb11, const int32_t nb12, const int64_t nb13,
+ const int32_t nb21, const int32_t nb22, const int64_t nb23,
+ const int32_t ne31, const int32_t ne32, const int32_t ne33,
+ const int32_t nb31, const int32_t nb32, const int64_t nb33) {
+#if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
+ // Skip unused kernel variants for faster compilation:
+ if (use_logit_softcap && !(D == 128 || D == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.
+
+ static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
+ static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
+ constexpr int frag_m = ncols == 8 ? 32 : 16;
+ constexpr int frag_n = ncols == 8 ? 8 : 16;
+ static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+ typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+ typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+ typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+ typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
+ typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
+
+ constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
+ constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
+ static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
+
+ // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
+ constexpr int D_padded = D + 8;
+ constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
+ constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
+
+ const int sequence = blockIdx.z / ne02;
+ const int head = blockIdx.z - sequence*ne02;
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+ const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+ const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+ const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+ const half2 * mask2 = (const half2 *) maskh;
+ const float * sinksf = (const float *) sinks;
+
+ const int stride_Q = nb01 / sizeof(float);
+ const int stride_KV = nb11 / sizeof(half);
+
+ const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
+ const half slopeh = __float2half(slopef);
+ const half2 slope2 = make_half2(slopef, slopef);
+
+ const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
+
+ frag_b Q_b[D/16][ncols/frag_n];
+
+ // A single buffer for temporarily holding tiles of KQ and VKQ parts:
+ constexpr int mem_KQ = ncols*kqs_padded*kqar;
+ constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
+ __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
+ float * KQ_f = (float *) KQ;
+ half2 * KQ2 = (half2 *) KQ;
+
+ float KQ_rowsum_f[ncols/nwarps] = {0.0f};
+ float KQ_max_f[ncols/nwarps];
+ float KQ_max_scale_f[ncols/nwarps] = {0.0f};
+
+#pragma unroll
+ for (int j = 0; j < ncols/nwarps; ++j) {
+ KQ_max_f[j] = -FLT_MAX/2.0f;
+ }
+
+ half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+ half2 KQ_max_h2[ncols/nwarps];
+ half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
+
+#pragma unroll
+ for (int j = 0; j < ncols/nwarps; ++j) {
+ KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
+ }
+
+ __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
+ half2 * VKQ2 = (half2 *) VKQ;
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+ if (i0 + warp_size > D/2 && i >= D/2) {
+ break;
+ }
+ VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
+ }
+ }
+
+ // Convert Q to half and apply scale, temporarily store in KQ:
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+#pragma unroll
+ for (int i0 = 0; i0 < D; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+ if (i0 + warp_size > D && i >= D) {
+ break;
+ }
+ KQ[j*D_padded + i] = ic0 + j < int(ne01.z) ? Q_f[j*stride_Q + i] * scale : 0.0f;
+ }
+ }
+
+ __syncthreads();
+
+ // Load Q into tensor core fragments/registers since it will be used frequently:
+#pragma unroll
+ for (int i0 = 0; i0 < D; i0 += 16) {
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+ wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+ }
+ }
+
+ __syncthreads();
+
+ // Iterate over ne11 == previous tokens:
+ const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+ for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
+ // Calculate tile of KQ:
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
+ frag_c_KQ KQ_c[ncols/frag_n];
+#pragma unroll
+ for (int j = 0; j < ncols/frag_n; ++j) {
+ wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
+ }
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
+ frag_a_K K_a;
+ wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+#pragma unroll
+ for (int j = 0; j < ncols/frag_n; ++j) {
+ wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+ }
+ }
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+ wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
+ }
+ }
+
+ __syncthreads();
+
+ // Calculate softmax for each KQ column using the current max. value.
+ // The divisor is stored in KQ_rowsum and will be applied at the end.
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (std::is_same<KQ_acc_t, float>::value) {
+ float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
+ const int k = k0 + threadIdx.x;
+
+ KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];
+
+ if (use_logit_softcap) {
+ KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]);
+ }
+ }
+
+ float KQ_max_new = KQ_max_f[j0/nwarps];
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
+ const int k = k0 + threadIdx.x;
+
+ KQ_f_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ?
+ __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+ KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size] + FATTN_KQ_MAX_OFFSET);
+ }
+ KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
+
+ const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
+ KQ_max_scale_f[j0/nwarps] = expf(diff);
+ if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+ KQ_max_scale_f[j0/nwarps] = 0.0f;
+ }
+ KQ_max_f[j0/nwarps] = KQ_max_new;
+
+ float KQ_rowsum_add = 0.0f;
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
+ const int k = k0 + threadIdx.x;
+
+ const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
+ KQ_f_tmp[k0/warp_size] = expf(diff);
+ if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+ KQ_f_tmp[k0/warp_size] = 0.0f;
+ }
+ KQ_rowsum_add += KQ_f_tmp[k0/warp_size];
+ KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size];
+ }
+ KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
+
+ // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+ KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
+ } else {
+ half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
+ const int k = k0 + threadIdx.x;
+
+ KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
+
+ if (use_logit_softcap) {
+                    // There is no dedicated hyperbolic tangent function for half2.
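+                    // Instead tanh(x) is evaluated as (exp(2x) - 1) / (exp(2x) + 1) in half2 arithmetic.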
+ KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f));
+ KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f))
+ /(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f));
+
+ KQ2_tmp[k0/warp_size] *= logit_softcap_2;
+ }
+ }
+
+ half2 KQ_max_new = KQ_max_h2[j0/nwarps];
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
+ const int k = k0 + threadIdx.x;
+
+ KQ2_tmp[k0/warp_size] += mask && ic0 + j < int(ne01.z) ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+ KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
+ }
+ KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+ const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
+ KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
+ const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+ *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
+ KQ_max_h2[j0/nwarps] = KQ_max_new;
+
+ half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
+ const int k = k0 + threadIdx.x;
+
+ const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
+ KQ2_tmp[k0/warp_size] = h2exp(diff);
+ const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
+ *((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask;
+ KQ_rowsum_add += KQ2_tmp[k0/warp_size];
+ KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size];
+ }
+ KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
+
+ // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+ KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
+ }
+ }
+
+ __syncthreads();
+
+ frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+ const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+ wmma::load_matrix_sync(
+ KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
+ KQ + j0*(kqar*kqs_padded) + k,
+ kqar*kqs_padded);
+ }
+ }
+
+ frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
+#pragma unroll
+ for (int j = 0; j < ncols/frag_n; ++j) {
+ wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
+ }
+
+#pragma unroll
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
+ const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
+
+ frag_a_V v_a;
+ wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+#pragma unroll
+ for (int j = 0; j < ncols/frag_n; ++j) {
+ wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+ }
+ }
+ }
+
+ __syncthreads();
+
+ const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += frag_n) {
+ wmma::store_matrix_sync(
+ KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+ VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
+ D_padded, wmma::mem_col_major);
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ half2 VKQ_scale;
+ if (std::is_same<KQ_acc_t, float>::value) {
+ VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
+ } else {
+ VKQ_scale = KQ_max_scale_h2[j0/nwarps];
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+ if (i0 + warp_size > D/2 && i >= D/2) {
+ break;
+ }
+
+ half2 VKQ_add = make_half2(0.0f, 0.0f);
+#pragma unroll
+ for (int l = 0; l < VKQ_ratio; ++l) {
+ VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
+ }
+ VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
+ }
+ }
+
+ __syncthreads();
+ }
+
+ // Apply attention sinks
+ if (sinksf && blockIdx.y == 0) {
+ const float sinkf = sinksf[head];
+ const half sinkh = __float2half(sinkf);
+
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (std::is_same<KQ_acc_t, float>::value) {
+ float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);
+
+ const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
+ KQ_max_f[j0/nwarps] = kqmax_new;
+
+ KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);
+
+ const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+ if (i0 + warp_size > D/2 && i >= D/2) break;
+ VKQ2[j*(D_padded/2) + i] *= scale_h2;
+ }
+ } else {
+ half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
+ half kqmax_new = fmaxf(kqmax_old, sinkh);
+ KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);
+
+ const half KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
+ const half2 KQ_max_scale = __half2half2(KQ_max_scale_h);
+
+ KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
+ const half val = hexp(sinkh - kqmax_new);
+ KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);
+
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+ if (i0 + warp_size > D/2 && i >= D/2) break;
+ VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
+ }
+ }
+ }
+
+ __syncthreads();
+ }
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j_VKQ = j0 + threadIdx.y;
+ if (ic0 + j_VKQ >= int(ne01.z)) {
+ return;
+ }
+
+ float KQ_rowsum_j;
+ if (std::is_same<KQ_acc_t, float>::value) {
+ KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
+ } else {
+ KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
+ }
+
+ const int j_dst_unrolled = ((sequence*int(ne01.z) + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < D; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+ if (i0 + warp_size > D && i >= D) {
+ break;
+ }
+ float dst_val = VKQ[j_VKQ*D_padded + i];
+ if (gridDim.y == 1) {
+ dst_val /= KQ_rowsum_j;
+ }
+ dst[j_dst_unrolled*D + i] = dst_val;
+ }
+
+ if (gridDim.y == 1 || threadIdx.x != 0) {
+ continue;
+ }
+
+ float2 dst_meta_val;
+ if (std::is_same<KQ_acc_t, float>::value) {
+ dst_meta_val.x = KQ_max_f[j0/nwarps];
+ } else {
+ dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
+ }
+ dst_meta_val.y = KQ_rowsum_j;
+ dst_meta[j_dst_unrolled] = dst_meta_val;
+ }
+#else
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+ max_bias, m0, m1, n_head_log2, logit_softcap,
+ ne00, ne01, ne02, ne03,
+ nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13,
+ nb11, nb12, nb13,
+ nb21, nb22, nb23,
+ ne31, ne32, ne33,
+ nb31, nb32, nb33);
+ NO_DEVICE_CODE;
+#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))
+}
+
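+// Largest power of 2 that evenly divides x, e.g. get_max_power_of_2(12) == 4: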
+constexpr int get_max_power_of_2(int x) {
+ return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
+}
+
+static_assert(get_max_power_of_2(1) == 1, "Test failed.");
+static_assert(get_max_power_of_2(2) == 2, "Test failed.");
+static_assert(get_max_power_of_2(4) == 4, "Test failed.");
+static_assert(get_max_power_of_2(6) == 2, "Test failed.");
+
+// Number of VKQ rows calculated in parallel:
+constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
+ return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
+}
+
+static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
+static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
+static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
+static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
+static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
+static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
+static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
+static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
+
+template <int D, int cols_per_block, typename KQ_acc_t>
+void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
+
+ constexpr int nwarps = 4;
+
+ constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
+ const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
+
+ float logit_softcap;
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+ fattn_kernel_t fattn_kernel;
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ fattn_kernel = flash_attn_ext_f16<
+ D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
+ } else {
+ constexpr bool use_logit_softcap = true;
+ fattn_kernel = flash_attn_ext_f16<
+ D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
+ }
+ launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
+}
+
+void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+
+ const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
+ const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
+
+ if (prec != GGML_PREC_DEFAULT) {
+ if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
+ constexpr int cols_per_block = 16;
+ switch (Q->ne[0]) {
+ case 64:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+ break;
+ case 80:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+ break;
+ case 96:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+ break;
+ case 112:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+ break;
+ case 128:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+ break;
+ case 256:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
+ break;
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+ } else {
+ constexpr int cols_per_block = 32;
+ switch (Q->ne[0]) {
+ case 64:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
+ break;
+ case 80:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
+ break;
+ case 96:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
+ break;
+ case 112:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
+ break;
+ case 128:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+ break;
+ // case 256:
+ // ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
+ // break;
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+ }
+ return;
+ }
+
+#if !defined(GGML_USE_HIP)
+ if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) {
+ constexpr int cols_per_block = 8;
+ switch (Q->ne[0]) {
+ case 64:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+ break;
+ case 96:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+ break;
+ case 128:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+ break;
+ case 256:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+ break;
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+ return;
+ }
+#endif // !defined(GGML_USE_HIP)
+
+ if (Q->ne[1] <= 32) {
+ constexpr int cols_per_block = 16;
+ switch (Q->ne[0]) {
+ case 64:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+ break;
+ case 80:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+ break;
+ case 96:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+ break;
+ case 112:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+ break;
+ case 128:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+ break;
+ case 256:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+ break;
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+ return;
+ }
+
+ constexpr int cols_per_block = 32;
+ switch (Q->ne[0]) {
+ case 64:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
+ break;
+ case 80:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
+ break;
+ case 96:
+ ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
+ break;
+ case 112:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
+ break;
+ case 128:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
+ break;
+ case 256:
+ ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
+ break;
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
new file mode 100644
index 0000000..cd3bfd4
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/fattn-wmma-f16.cuh
@@ -0,0 +1,51 @@
+#pragma once
+
+#include "common.cuh"
+
+#if defined(GGML_USE_MUSA)
+#define GGML_USE_WMMA_FATTN
+#endif // defined(GGML_USE_MUSA)
+
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+#define GGML_USE_WMMA_FATTN
+#elif defined(CDNA)
+#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance"
+#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+#if defined(RDNA3)
+#define GGML_USE_WMMA_FATTN
+#endif // defined(RDNA3)
+#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
+#define GGML_USE_WMMA_FATTN
+#elif defined(RDNA4)
+#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance"
+#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
+
+// WMMA flash attention requires FP16 matrix instructions to be available for ggml code.
+static bool ggml_cuda_should_use_wmma_fattn(const int cc) {
+#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
+ return false;
+#else
+ if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) ||
+ GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) {
+ return true;
+    } else if (GGML_CUDA_CC_IS_CDNA(cc)) {
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+ return true;
+#else
+ return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0)
+ } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
+ return true;
+#else
+ return false;
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1
+ } else {
+ return false;
+ }
+#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN)
+}
+
+void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn.cu b/llama.cpp/ggml/src/ggml-cuda/fattn.cu
new file mode 100644
index 0000000..721edd9
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/fattn.cu
@@ -0,0 +1,482 @@
+#include "common.cuh"
+#include "fattn-common.cuh"
+#include "fattn-mma-f16.cuh"
+#include "fattn-tile.cuh"
+#include "fattn-vec.cuh"
+#include "fattn-wmma-f16.cuh"
+#include "fattn.cuh"
+
+template <int DKQ, int DV, int ncols2>
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ const ggml_tensor * Q = dst->src[0];
+
+ if constexpr (ncols2 <= 8) {
+ if (turing_mma_available(cc) && Q->ne[1] <= 8/ncols2) {
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
+ return;
+ }
+ }
+
+ if constexpr (ncols2 <= 16) {
+ if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
+ return;
+ }
+ }
+
+ if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
+ return;
+ }
+
+ ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
+}
+
+template <int DKQ, int DV>
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+ const ggml_tensor * mask = dst->src[3];
+
+ float max_bias = 0.0f;
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+ // Edge cases like no mask, ALiBi, unpadded K/V, or misaligned addresses for large data transfers
+ // are put into the template specialization without GQA optimizations.
+ bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+ for (const ggml_tensor * t : {Q, K, V, mask}) {
+ if (t == nullptr || ggml_is_quantized(t->type)) {
+ continue;
+ }
+ for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+ if (t->nb[i] % 16 != 0) {
+ use_gqa_opt = false;
+ break;
+ }
+ }
+ }
+
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
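+    // ncols2 (chosen below from gqa_ratio) groups Q heads that share the same K/V head into one
+    // tile so K/V loads can be reused across them; it is only raised above 1 when use_gqa_opt holds.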
+
+ // On Volta the GQA optimizations aren't as impactful vs. minimizing wasted compute:
+ if (cc == GGML_CUDA_CC_VOLTA) {
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
+ return;
+ }
+
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
+ return;
+ }
+
+ if (use_gqa_opt && gqa_ratio % 2 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
+ return;
+ }
+
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
+ return;
+ }
+
+ if (use_gqa_opt && gqa_ratio > 4) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
+ return;
+ }
+
+ if (use_gqa_opt && gqa_ratio > 2) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
+ return;
+ }
+
+ if (use_gqa_opt && gqa_ratio > 1) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
+ return;
+ }
+
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
+}
+
+static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+ const ggml_tensor * mask = dst->src[3];
+
+ switch (Q->ne[0]) {
+ case 64:
+ GGML_ASSERT(V->ne[0] == 64);
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64, 64>(ctx, dst);
+ break;
+ case 80:
+ GGML_ASSERT(V->ne[0] == 80);
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80, 80>(ctx, dst);
+ break;
+ case 96:
+ GGML_ASSERT(V->ne[0] == 96);
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96, 96>(ctx, dst);
+ break;
+ case 112:
+ GGML_ASSERT(V->ne[0] == 112);
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
+ break;
+ case 128:
+ GGML_ASSERT(V->ne[0] == 128);
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
+ break;
+ case 256:
+ GGML_ASSERT(V->ne[0] == 256);
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
+ break;
+ case 576: {
+ // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
+ GGML_ASSERT(V->ne[0] == 512);
+ float max_bias = 0.0f;
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+ const bool use_gqa_opt = mask && max_bias == 0.0f;
+ GGML_ASSERT(use_gqa_opt);
+
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
+ if (gqa_ratio == 20) { // GLM 4.7 Flash
+ if (cc >= GGML_CUDA_CC_DGX_SPARK) {
+ if (Q->ne[1] <= 8) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ break;
+ }
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ break;
+ }
+ if (cc >= GGML_CUDA_CC_BLACKWELL) {
+ if (Q->ne[1] <= 4 && K->ne[1] >= 65536) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ break;
+ }
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ break;
+ }
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+ if (Q->ne[1] <= 4) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ break;
+ }
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ break;
+ }
+ if (cc >= GGML_CUDA_CC_TURING) {
+ if (Q->ne[1] <= 4) {
+ if (K->ne[1] <= 16384) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ break;
+ }
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst);
+ break;
+ }
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ break;
+ }
+ // Volta:
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ } else if (gqa_ratio % 16 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ }
+ } break;
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+}
+
+#define FATTN_VEC_CASE(D, type_K, type_V) \
+ { \
+ const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
+ const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
+ if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
+ ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
+ return; \
+ } \
+ } \
+
+#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
+ FATTN_VEC_CASE( 64, type_K, type_V) \
+ FATTN_VEC_CASE(128, type_K, type_V) \
+ FATTN_VEC_CASE(256, type_K, type_V) \
+
+static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_tensor * Q = dst->src[0];
+ ggml_tensor * K = dst->src[1];
+ ggml_tensor * V = dst->src[2];
+
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+#else
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+#endif // GGML_CUDA_FA_ALL_QUANTS
+
+ GGML_ABORT("fatal error");
+}
+
+// Best FlashAttention kernel for a specific GPU:
+enum best_fattn_kernel {
+ BEST_FATTN_KERNEL_NONE = 0,
+ BEST_FATTN_KERNEL_TILE = 200,
+ BEST_FATTN_KERNEL_VEC = 100,
+ BEST_FATTN_KERNEL_WMMA_F16 = 300,
+ BEST_FATTN_KERNEL_MMA_F16 = 400,
+};
+
+static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
+#ifndef FLASH_ATTN_AVAILABLE
+ GGML_UNUSED(device); GGML_UNUSED(dst);
+ return BEST_FATTN_KERNEL_NONE;
+#endif // FLASH_ATTN_AVAILABLE
+
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+ const ggml_tensor * mask = dst->src[3];
+
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+
+ float max_bias = 0.0f;
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+ // The effective batch size for the kernel can be increased by gqa_ratio.
+    // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded.
+ bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+ for (const ggml_tensor * t : {Q, K, V, mask}) {
+ if (t == nullptr || ggml_is_quantized(t->type)) {
+ continue;
+ }
+ for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+ if (t->nb[i] % 16 != 0) {
+ gqa_opt_applies = false;
+ break;
+ }
+ }
+ }
+
+ const int cc = ggml_cuda_info().devices[device].cc;
+
+ switch (K->ne[0]) {
+ case 40:
+ case 64:
+ case 72:
+ case 80:
+ case 96:
+ case 128:
+ case 112:
+ case 256:
+ if (V->ne[0] != K->ne[0]) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+ break;
+ case 576:
+ if (V->ne[0] != 512) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+ if (!gqa_opt_applies) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+ break;
+ default:
+ return BEST_FATTN_KERNEL_NONE;
+ }
+
+#ifndef GGML_CUDA_FA_ALL_QUANTS
+ if (K->type != V->type) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+#endif // GGML_CUDA_FA_ALL_QUANTS
+
+ switch (K->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ break;
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+#ifndef GGML_CUDA_FA_ALL_QUANTS
+ return BEST_FATTN_KERNEL_NONE;
+#endif // GGML_CUDA_FA_ALL_QUANTS
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q8_0:
+ break;
+ default:
+ return BEST_FATTN_KERNEL_NONE;
+ }
+
+ if (mask && mask->ne[2] != 1) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+
+ // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
+ const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
+
+ // If Turing tensor cores are available, use them:
+ if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
+ if (can_use_vector_kernel) {
+ if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ } else {
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+ if (Q->ne[1] <= 2) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ } else {
+ if (Q->ne[1] == 1) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ }
+ }
+ if (!gqa_opt_applies && Q->ne[1] == 1) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ }
+ return BEST_FATTN_KERNEL_MMA_F16;
+ }
+
+ if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) {
+ int gqa_ratio_eff = 1;
+ const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
+ while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
+ gqa_ratio_eff *= 2;
+ }
+ if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ if (Q->ne[1] * gqa_ratio_eff <= 16) {
+ return BEST_FATTN_KERNEL_TILE; // On Volta tensor cores are only faster for sufficiently large matrices.
+ }
+ return BEST_FATTN_KERNEL_MMA_F16;
+ }
+
+ // Use the WMMA kernel if possible:
+ if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
+ if (can_use_vector_kernel && Q->ne[1] <= 2) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ return BEST_FATTN_KERNEL_WMMA_F16;
+ }
+
+ if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) {
+ if (can_use_vector_kernel) {
+ if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+ if (Q->ne[1] == 1) {
+ if (!gqa_opt_applies) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ }
+ } else {
+ if (Q->ne[1] <= 2) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ }
+ }
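+        // Same effective GQA batching computation as in the Volta path above.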
+ int gqa_ratio_eff = 1;
+ const int ncols2_max = Q->ne[0] == 576 ? 16 : 8;
+ while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) {
+ gqa_ratio_eff *= 2;
+ }
+ if (Q->ne[1] * gqa_ratio_eff <= 8) {
+ return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized.
+ }
+ return BEST_FATTN_KERNEL_MMA_F16;
+ }
+
+    // If no tensor cores are available, use the vector kernel for very small batches and the generic tile kernel otherwise:
+ if (can_use_vector_kernel) {
+ if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+ if (Q->ne[1] == 1) {
+ if (!gqa_opt_applies) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ }
+ } else {
+ if (Q->ne[1] <= 2) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ }
+ }
+ return BEST_FATTN_KERNEL_TILE;
+}
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_set_device(ctx.device);
+ switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
+ case BEST_FATTN_KERNEL_NONE:
+ GGML_ABORT("fatal error");
+ case BEST_FATTN_KERNEL_TILE:
+ ggml_cuda_flash_attn_ext_tile(ctx, dst);
+ break;
+ case BEST_FATTN_KERNEL_VEC:
+ ggml_cuda_flash_attn_ext_vec(ctx, dst);
+ break;
+ case BEST_FATTN_KERNEL_WMMA_F16:
+ ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+ break;
+ case BEST_FATTN_KERNEL_MMA_F16:
+ ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
+ break;
+ }
+}
+
+bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
+ return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/fattn.cuh b/llama.cpp/ggml/src/ggml-cuda/fattn.cuh
new file mode 100644
index 0000000..78705d5
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/fattn.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/fill.cu b/llama.cpp/ggml/src/ggml-cuda/fill.cu
new file mode 100644
index 0000000..739062c
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/fill.cu
@@ -0,0 +1,37 @@
+#include "fill.cuh"
+#include "convert.cuh"
+
+#define CUDA_FILL_BLOCK_SIZE 256
+
+template <typename T>
+static __global__ void fill_kernel(T * dst, const int64_t k, const T value) {
+ const int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x;
+ if (i >= k) {
+ return;
+ }
+ dst[i] = value;
+}
+
+void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ void * dst_d = dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ float value;
+ memcpy(&value, dst->op_params, sizeof(float));
+
+ const int64_t k = ggml_nelements(dst);
+ const int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE;
+
+ switch (dst->type) {
+ case GGML_TYPE_F32:
+ fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((float *)dst_d, k, value);
+ break;
+ case GGML_TYPE_F16:
+ fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((half *)dst_d, k, ggml_cuda_cast<half>(value));
+ break;
+ default:
+ GGML_ABORT("unsupported type");
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/fill.cuh b/llama.cpp/ggml/src/ggml-cuda/fill.cuh
new file mode 100644
index 0000000..8443c83
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/fill.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/getrows.cu b/llama.cpp/ggml/src/ggml-cuda/getrows.cu
new file mode 100644
index 0000000..2fab332
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/getrows.cu
@@ -0,0 +1,286 @@
+#include "getrows.cuh"
+#include "dequantize.cuh"
+#include "convert.cuh"
+
+template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static __global__ void k_get_rows(
+ const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
+ const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
+ /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
+ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
+ /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
+ const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
+
+ for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
+ for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
+ // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+ const int i10 = blockIdx.x;
+ const int i11 = z / ne12; // TODO fastdiv
+ const int i12 = z % ne12;
+
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
+
+ const int ib = i00/qk; // block index
+ const int iqs = (i00%qk)/qr; // quant index
+ const int iybs = i00 - i00%qk; // dst block start index
+ const int y_offset = qr == 1 ? 1 : qk/2;
+
+ // dequantize
+ float2 v;
+ dequantize_kernel(src0_row, ib, iqs, v);
+
+ dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
+ dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
+ }
+ }
+}
+
+template<typename src0_t, typename dst_t>
+static __global__ void k_get_rows_float(
+ const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
+ const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
+ /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
+ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
+ /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
+ const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
+
+ for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
+ for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
+ // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+ const int i10 = blockIdx.x;
+ const int i11 = z / ne12; // TODO fastdiv
+ const int i12 = z % ne12;
+
+ if (i00 >= ne00) {
+ return;
+ }
+
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+ dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
+ }
+ }
+}
+
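+// Backward pass of GGML_OP_GET_ROWS: every output row accumulates the gradients of all forward-pass rows that
+// gathered it. Note that this scans all nrows_grad gradient rows for every output element.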
+template<typename grad_t, typename dst_t>
+static __global__ void k_get_rows_back_float(
+ const grad_t * __restrict__ grad, const int32_t * __restrict__ rows, dst_t * __restrict__ dst, const int64_t ncols, const int64_t nrows_grad) {
+ const int col = blockIdx.x*blockDim.x + threadIdx.x;
+
+ if (col >= ncols) {
+ return;
+ }
+
+ const int dst_row = blockIdx.y*blockDim.y + threadIdx.y;
+
+ float sum = 0.0f;
+
+ for (int64_t i = 0; i < nrows_grad; ++i) {
+ if (rows[i] != dst_row) {
+ continue;
+ }
+ sum += grad[i*ncols + col];
+ }
+
+ dst[dst_row*ncols + col] = sum;
+}
+
+template<int qk, int qr, dequantize_kernel_t dq, typename dst_t>
+static void get_rows_cuda_q(
+ const void * src0_d, const int32_t * src1_d, dst_t * dst_d,
+ const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
+ const size_t nb1, const size_t nb2, const size_t nb3,
+ cudaStream_t stream) {
+ const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+ const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+ const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));
+
+ // strides in elements
+ // const size_t s0 = nb0 / sizeof(dst_t);
+ const size_t s1 = nb1 / sizeof(dst_t);
+ const size_t s2 = nb2 / sizeof(dst_t);
+ const size_t s3 = nb3 / sizeof(dst_t);
+
+ const size_t s10 = nb10 / sizeof(int32_t);
+ const size_t s11 = nb11 / sizeof(int32_t);
+ const size_t s12 = nb12 / sizeof(int32_t);
+ // const size_t s13 = nb13 / sizeof(int32_t);
+
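+    // each thread writes two dequantized values per iteration (v.x and v.y), hence ne00 must be even
+    // and the y dimension of the grid only needs to cover ne00/2 positions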
+ GGML_ASSERT(ne00 % 2 == 0);
+
+ k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
+ src0_d, src1_d, dst_d,
+ ne00, /*ne01, ne02, ne03,*/
+ /*ne10,*/ ne11, ne12, /*ne13,*/
+ /* s0,*/ s1, s2, s3,
+ /* nb00,*/ nb01, nb02, nb03,
+ s10, s11, s12/*, s13*/);
+}
+
+template<typename src0_t, typename dst_t>
+static void get_rows_cuda_float(
+ const src0_t * src0_d, const int32_t * src1_d, dst_t * dst_d,
+ const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
+ const size_t nb1, const size_t nb2, const size_t nb3,
+ cudaStream_t stream) {
+ const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+ const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
+ const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));
+
+ // strides in elements
+ // const size_t s0 = nb0 / sizeof(dst_t);
+ const size_t s1 = nb1 / sizeof(dst_t);
+ const size_t s2 = nb2 / sizeof(dst_t);
+ const size_t s3 = nb3 / sizeof(dst_t);
+
+ const size_t s10 = nb10 / sizeof(int32_t);
+ const size_t s11 = nb11 / sizeof(int32_t);
+ const size_t s12 = nb12 / sizeof(int32_t);
+ // const size_t s13 = nb13 / sizeof(int32_t);
+
+ k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
+ src0_d, src1_d, dst_d,
+ ne00, /*ne01, ne02, ne03,*/
+ /*ne10,*/ ne11, ne12, /*ne13,*/
+ /* s0,*/ s1, s2, s3,
+ /* nb00,*/ nb01, nb02, nb03,
+ s10, s11, s12/*, s13*/);
+}
+
+template <typename dst_t>
+static void ggml_cuda_get_rows_switch_src0_type(
+ const void * src0_d, const ggml_type src0_type, const int32_t * src1_d, dst_t * dst_d,
+ const int64_t ne00, const size_t nb01, const size_t nb02, const size_t nb03,
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const size_t nb10, const size_t nb11, const size_t nb12,
+ const size_t nb1, const size_t nb2, const size_t nb3,
+ cudaStream_t stream) {
+ switch (src0_type) {
+ case GGML_TYPE_F16:
+ get_rows_cuda_float((const half *) src0_d, src1_d, dst_d,
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+ break;
+ case GGML_TYPE_F32:
+ get_rows_cuda_float((const float *) src0_d, src1_d, dst_d,
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+ break;
+ case GGML_TYPE_I32:
+ get_rows_cuda_float((const int32_t *) src0_d, src1_d, dst_d,
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+ break;
+ case GGML_TYPE_BF16:
+ get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d,
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+ break;
+ case GGML_TYPE_Q4_0:
+ get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d,
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+ break;
+ case GGML_TYPE_Q4_1:
+ get_rows_cuda_q<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_d, dst_d,
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+ break;
+ case GGML_TYPE_Q5_0:
+ get_rows_cuda_q<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_d, dst_d,
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+ break;
+ case GGML_TYPE_Q5_1:
+ get_rows_cuda_q<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_d, dst_d,
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+ break;
+ case GGML_TYPE_Q8_0:
+ get_rows_cuda_q<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_d, dst_d,
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+ break;
+ default:
+ // TODO: k-quants
+ GGML_ABORT("%s: unsupported src0 type: %s\n", __func__, ggml_type_name(src0_type));
+ break;
+ }
+}
+
+void get_rows_cuda(
+ const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type,
+ int64_t ne00, size_t nb01, size_t nb02, size_t nb03,
+ int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12,
+ size_t nb1, size_t nb2, size_t nb3,
+ cudaStream_t stream) {
+ switch (dst_type) {
+ case GGML_TYPE_F32:
+ ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (float *) dst_d,
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+ break;
+ case GGML_TYPE_I32:
+ ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (int32_t *) dst_d,
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+ break;
+ case GGML_TYPE_F16:
+ ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (half *) dst_d,
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+ break;
+ case GGML_TYPE_BF16:
+ ggml_cuda_get_rows_switch_src0_type(src0_d, src0_type, src1_d, (nv_bfloat16 *) dst_d,
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+ break;
+ default:
+ GGML_ABORT("%s: unsupported dst type: %s\n", __func__, ggml_type_name(dst_type));
+ break;
+ }
+}
+
+void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
+ GGML_ASSERT(ne13 == 1);
+
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+ GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+ GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
+
+ get_rows_cuda(src0->data, src0->type, (const int32_t *) src1->data, dst->data, dst->type,
+ ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream);
+}
+
+void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
+ const ggml_tensor * src1 = dst->src[1]; // src1 in forward pass
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const float * src0_d = (const float *) src0->data;
+ const int32_t * src1_d = (const int32_t *) src1->data;
+ float * dst_d = (float *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ GGML_ASSERT(ne02*ne03 == 1);
+ GGML_ASSERT(ne12*ne13 == 1);
+ GGML_ASSERT(ne2*ne3 == 1);
+
+ const dim3 block_dims(CUDA_GET_ROWS_BACK_BLOCK_SIZE, 1, 1);
+ const int block_num_x = (ne00 + CUDA_GET_ROWS_BACK_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BACK_BLOCK_SIZE;
+ const dim3 block_nums(block_num_x, ne1, 1);
+
+ k_get_rows_back_float<<<block_nums, block_dims, 0, stream>>>(src0_d, src1_d, dst_d, ne00, ne10);
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/getrows.cuh b/llama.cpp/ggml/src/ggml-cuda/getrows.cuh
new file mode 100644
index 0000000..3c5bea5
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/getrows.cuh
@@ -0,0 +1,15 @@
+#include "common.cuh"
+
+#define CUDA_GET_ROWS_BLOCK_SIZE 256
+#define CUDA_GET_ROWS_BACK_BLOCK_SIZE 256
+
+void get_rows_cuda(
+ const void * src0_d, ggml_type src0_type, const int32_t * src1_d, void * dst_d, ggml_type dst_type,
+ int64_t ne00, size_t nb01, size_t nb02, size_t nb03,
+ int64_t ne10, int64_t ne11, int64_t ne12, size_t nb10, size_t nb11, size_t nb12,
+ size_t nb1, size_t nb2, size_t nb3,
+ cudaStream_t stream);
+
+void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_get_rows_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu b/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu
new file mode 100644
index 0000000..b163468
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -0,0 +1,5118 @@
+#include "ggml-cuda.h"
+#include "ggml-impl.h"
+#include "ggml-backend-impl.h"
+
+#include "ggml-cuda/common.cuh"
+#include "ggml-cuda/acc.cuh"
+#include "ggml-cuda/add-id.cuh"
+#include "ggml-cuda/arange.cuh"
+#include "ggml-cuda/argmax.cuh"
+#include "ggml-cuda/argsort.cuh"
+#include "ggml-cuda/binbcast.cuh"
+#include "ggml-cuda/clamp.cuh"
+#include "ggml-cuda/concat.cuh"
+#include "ggml-cuda/conv-transpose-1d.cuh"
+#include "ggml-cuda/conv2d.cuh"
+#include "ggml-cuda/conv2d-dw.cuh"
+#include "ggml-cuda/conv2d-transpose.cuh"
+#include "ggml-cuda/convert.cuh"
+#include "ggml-cuda/count-equal.cuh"
+#include "ggml-cuda/cpy.cuh"
+#include "ggml-cuda/cross-entropy-loss.cuh"
+#include "ggml-cuda/cumsum.cuh"
+#include "ggml-cuda/diagmask.cuh"
+#include "ggml-cuda/diag.cuh"
+#include "ggml-cuda/fattn.cuh"
+#include "ggml-cuda/getrows.cuh"
+#include "ggml-cuda/im2col.cuh"
+#include "ggml-cuda/mmf.cuh"
+#include "ggml-cuda/mmq.cuh"
+#include "ggml-cuda/mmvf.cuh"
+#include "ggml-cuda/mmvq.cuh"
+#include "ggml-cuda/norm.cuh"
+#include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/opt-step-sgd.cuh"
+#include "ggml-cuda/out-prod.cuh"
+#include "ggml-cuda/pad.cuh"
+#include "ggml-cuda/pool2d.cuh"
+#include "ggml-cuda/quantize.cuh"
+#include "ggml-cuda/rope.cuh"
+#include "ggml-cuda/roll.cuh"
+#include "ggml-cuda/scale.cuh"
+#include "ggml-cuda/softcap.cuh"
+#include "ggml-cuda/softmax.cuh"
+#include "ggml-cuda/ssm-conv.cuh"
+#include "ggml-cuda/ssm-scan.cuh"
+#include "ggml-cuda/sum.cuh"
+#include "ggml-cuda/sumrows.cuh"
+#include "ggml-cuda/top-k.cuh"
+#include "ggml-cuda/mean.cuh"
+#include "ggml-cuda/tsembd.cuh"
+#include "ggml-cuda/topk-moe.cuh"
+#include "ggml-cuda/unary.cuh"
+#include "ggml-cuda/upscale.cuh"
+#include "ggml-cuda/wkv.cuh"
+#include "ggml-cuda/gla.cuh"
+#include "ggml-cuda/set.cuh"
+#include "ggml-cuda/set-rows.cuh"
+#include "ggml-cuda/pad_reflect_1d.cuh"
+#include "ggml-cuda/solve_tri.cuh"
+#include "ggml-cuda/tri.cuh"
+#include "ggml-cuda/cumsum.cuh"
+#include "ggml-cuda/fill.cuh"
+#include "ggml.h"
+
+#include <algorithm>
+#include <array>
+#include <atomic>
+#include <charconv>
+#include <cinttypes>
+#include <condition_variable>
+#include <cstddef>
+#include <cstdint>
+#include <cfloat>
+#include <initializer_list>
+#include <limits>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <cstdarg>
+#include <cstdio>
+#include <cstdlib>
+#include <string>
+#include <vector>
+#include <unordered_set>
+
+static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
+
+[[noreturn]]
+void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
+ int id = -1; // in case cudaGetDevice fails
+ (void)cudaGetDevice(&id);
+
+ GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg);
+ GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line);
+ GGML_LOG_ERROR(" %s\n", stmt);
+ // abort with GGML_ABORT to get a stack trace
+ GGML_ABORT(GGML_CUDA_NAME " error");
+}
+
+// Checking the current device and returning early is faster on Windows,
+// probably because the Windows CUDA libraries do not perform this check themselves before invoking the driver.
+void ggml_cuda_set_device(int device) {
+ int current_device;
+ CUDA_CHECK(cudaGetDevice(&current_device));
+
+ if (device == current_device) {
+ return;
+ }
+
+ CUDA_CHECK(cudaSetDevice(device));
+}
+
+int ggml_cuda_get_device() {
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+ return id;
+}
+
+static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
+ ggml_cuda_set_device(device);
+ cudaError_t err;
+ if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
+ err = cudaMallocManaged(ptr, size);
+#if defined(GGML_USE_HIP)
+ if (err == hipSuccess) {
+ CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+ }
+
+ // fall back to cudaMalloc if not supported (e.g. on Windows)
+ if (err == hipErrorNotSupported) {
+ static bool warned_unsupported = false;
+ if (!warned_unsupported) {
+ GGML_LOG_WARN("hipMallocManaged unsupported, falling back to hipMalloc.\n");
+ warned_unsupported = true;
+ }
+
+ err = cudaMalloc(ptr, size);
+ }
+#endif // defined(GGML_USE_HIP)
+ } else {
+ err = cudaMalloc(ptr, size);
+ }
+ return err;
+}
+
+#if defined(GGML_USE_HIP)
+static int ggml_cuda_parse_id(char devName[]) {
+ // A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp
+ // these values are not stable so this is susceptible to breakage
+ // https://github.com/ROCm/clr/blob/amd-staging/rocclr/device/device.cpp
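+    // e.g. "gfx1030" parses to archMajor = 0x10 and archMinor = 0x30, giving cc = GGML_CUDA_CC_OFFSET_AMD + 0x1030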
+ int archMajor = 0x0;
+ int archMinor = 0x0;
+ int archNum = GGML_CUDA_CC_OFFSET_AMD;
+ int archLen = strlen(devName);
+ char archName[archLen + 1];
+
+ // strip leading 'gfx' while copying into our buffer
+ if (archLen > 3) {
+ strcpy(archName, &devName[3]);
+ archLen -= 3;
+ }
+
+ // trim trailing :xnack- or :sramecc- statuses
+ archLen = strcspn(archName, ":");
+ archName[archLen] = '\0';
+
+ // tease out the version information
+ if (archLen > 8) {
+ // versions labeled generic use '-' as delimiter
+ // strip the trailing "-generic" then iterate through what remains
+ if ((strstr(archName, "-generic"))) {
+ archName[archLen - 8] = '\0';
+ char * pch;
+ if ((pch = strtok(archName, "-"))) {
+ archMajor = (int)strtoul(pch, 0, 16);
+ if ((pch = strtok(NULL, "-"))) {
+ archMinor = 0x10 * (int)strtoul(pch, 0, 16);
+ }
+ }
+ }
+ } else if (archLen >= 3) {
+ // last two digits should be the minor * 0x10 + stepping
+ archMinor = (int)strtoul(&archName[archLen - 2], 0, 16);
+ archName[archLen - 2] = '\0';
+
+ // only the major version remains
+ archMajor = (int)strtoul(archName, 0, 16);
+ }
+ archNum += archMajor * 0x100;
+ archNum += archMinor;
+ return archNum;
+}
+#endif // defined(GGML_USE_HIP)
+
+static ggml_cuda_device_info ggml_cuda_init() {
+ ggml_cuda_device_info info = {};
+
+ cudaError_t err = cudaGetDeviceCount(&info.device_count);
+ if (err != cudaSuccess) {
+ GGML_LOG_ERROR("%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err));
+ return info;
+ }
+
+ GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
+
+ int64_t total_vram = 0;
+ GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
+
+ std::vector<std::pair<int, std::string>> turing_devices_without_mma;
+ for (int id = 0; id < info.device_count; ++id) {
+ int device_vmm = 0;
+
+#if defined(GGML_USE_VMM)
+ CUdevice device;
+ CU_CHECK(cuDeviceGet(&device, id));
+ CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
+
+ if (device_vmm) {
+ CUmemAllocationProp alloc_prop = {};
+ alloc_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+ alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+ alloc_prop.location.id = id;
+ CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
+ }
+#endif // defined(GGML_USE_VMM)
+ info.devices[id].vmm = !!device_vmm;
+
+ cudaDeviceProp prop;
+ CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
+
+ info.default_tensor_split[id] = total_vram;
+ total_vram += prop.totalGlobalMem;
+ info.devices[id].integrated = false; // Temporarily disabled due to issues with corrupted output (e.g. #15034)
+ info.devices[id].nsm = prop.multiProcessorCount;
+ info.devices[id].smpb = prop.sharedMemPerBlock;
+ info.devices[id].warp_size = prop.warpSize;
+
+#ifndef GGML_USE_MUSA
+ int supports_coop_launch = 0;
+ CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id));
+ info.devices[id].supports_cooperative_launch = !!supports_coop_launch;
+#else
+ info.devices[id].supports_cooperative_launch = false;
+#endif // !(GGML_USE_MUSA)
+#if defined(GGML_USE_HIP)
+ info.devices[id].smpbo = prop.sharedMemPerBlock;
+
+ info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName);
+ if ((info.devices[id].cc & 0xff00) == 0x0) {
+ GGML_LOG_WARN("invalid architecture ID received for device %d %s: %s cc %d.%d\n",
+ id, prop.name, prop.gcnArchName, prop.major, prop.minor);
+
+ // Fallback to prop.major and prop.minor
+ if (prop.major > 0) {
+ info.devices[id].cc = GGML_CUDA_CC_OFFSET_AMD + prop.major * 0x100;
+ info.devices[id].cc += prop.minor * 0x10;
+ }
+ }
+ GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
+ id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
+ device_vmm ? "yes" : "no", prop.warpSize);
+#elif defined(GGML_USE_MUSA)
+ // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
+ info.devices[id].warp_size = 32;
+ info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
+ info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
+ info.devices[id].cc += prop.minor * 0x10;
+ GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
+ id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
+#else
+ info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
+ info.devices[id].cc = 100*prop.major + 10*prop.minor;
+ GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
+ id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
+ std::string device_name(prop.name);
+ if (device_name == "NVIDIA GeForce MX450") {
+ turing_devices_without_mma.push_back({ id, device_name });
+ } else if (device_name == "NVIDIA GeForce MX550") {
+ turing_devices_without_mma.push_back({ id, device_name });
+ } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
+ turing_devices_without_mma.push_back({ id, device_name });
+ }
+
+ // Temporary performance fix:
+ // Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls.
+        // TODO: check the default scheduling strategy in future drivers and remove this call
+        // once cudaDeviceScheduleSpin is the default.
+ if (prop.major == 12 && prop.minor == 1) {
+ CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
+ }
+
+#endif // defined(GGML_USE_HIP)
+ }
+
+ if (ggml_cuda_highest_compiled_arch(GGML_CUDA_CC_TURING) >= GGML_CUDA_CC_TURING && !turing_devices_without_mma.empty()) {
+ GGML_LOG_INFO("The following devices will have suboptimal performance due to a lack of tensor cores:\n");
+ for (size_t device_pos = 0; device_pos < turing_devices_without_mma.size(); device_pos++) {
+ GGML_LOG_INFO(
+ " Device %d: %s\n", turing_devices_without_mma[device_pos].first, turing_devices_without_mma[device_pos].second.c_str());
+ }
+ GGML_LOG_INFO(
+            "Consider compiling with -DCMAKE_CUDA_ARCHITECTURES=61-virtual;80-virtual and -DGGML_CUDA_FORCE_MMQ to force the use of the Pascal code for Turing.\n");
+ }
+
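+    // convert the cumulative per-device VRAM sums into fractions of the total VRAM, used as the default tensor split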
+ for (int id = 0; id < info.device_count; ++id) {
+ info.default_tensor_split[id] /= total_vram;
+ }
+
+ // configure logging to stdout
+ // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
+
+ return info;
+}
+
+const ggml_cuda_device_info & ggml_cuda_info() {
+ static ggml_cuda_device_info info = ggml_cuda_init();
+ return info;
+}
+
+// #define DEBUG_CUDA_MALLOC
+
+// buffer pool for cuda (legacy)
+struct ggml_cuda_pool_leg : public ggml_cuda_pool {
+ static const int MAX_BUFFERS = 256;
+
+ int device;
+ struct ggml_cuda_buffer {
+ void * ptr = nullptr;
+ size_t size = 0;
+ };
+
+ ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
+ size_t pool_size = 0;
+
+ explicit ggml_cuda_pool_leg(int device) :
+ device(device) {
+ }
+
+ ~ggml_cuda_pool_leg() {
+ ggml_cuda_set_device(device);
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
+ ggml_cuda_buffer & b = buffer_pool[i];
+ if (b.ptr != nullptr) {
+ CUDA_CHECK(cudaFree(b.ptr));
+ pool_size -= b.size;
+ }
+ }
+ GGML_ASSERT(pool_size == 0);
+ }
+
+ void * alloc(size_t size, size_t * actual_size) override {
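+        // best-fit search: reuse the smallest pooled buffer that can hold `size` bytes; an exact fit is returned immediately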
+#ifdef DEBUG_CUDA_MALLOC
+ int nnz = 0;
+ size_t max_size = 0;
+#endif
+ size_t best_diff = 1ull << 36;
+ int ibest = -1;
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
+ ggml_cuda_buffer& b = buffer_pool[i];
+ if (b.ptr != nullptr) {
+#ifdef DEBUG_CUDA_MALLOC
+ ++nnz;
+ if (b.size > max_size) max_size = b.size;
+#endif
+ if (b.size >= size) {
+ size_t diff = b.size - size;
+ if (diff < best_diff) {
+ best_diff = diff;
+ ibest = i;
+ if (!best_diff) {
+ void * ptr = b.ptr;
+ *actual_size = b.size;
+ b.ptr = nullptr;
+ b.size = 0;
+ return ptr;
+ }
+ }
+ }
+ }
+ }
+ if (ibest >= 0) {
+ ggml_cuda_buffer& b = buffer_pool[ibest];
+ void * ptr = b.ptr;
+ *actual_size = b.size;
+ b.ptr = nullptr;
+ b.size = 0;
+ return ptr;
+ }
+ void * ptr;
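+        // nothing suitable in the pool: allocate a fresh buffer with ~5% headroom, rounded up to a multiple of 256 bytes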
+ size_t look_ahead_size = (size_t) (1.05 * size);
+ look_ahead_size = 256 * ((look_ahead_size + 255)/256);
+ ggml_cuda_set_device(device);
+ CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
+ *actual_size = look_ahead_size;
+ pool_size += look_ahead_size;
+#ifdef DEBUG_CUDA_MALLOC
+ GGML_LOG_INFO("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz,
+ (uint32_t)(max_size / 1024 / 1024), (uint32_t)(pool_size / 1024 / 1024), (uint32_t)(size / 1024 / 1024));
+#endif
+ return ptr;
+ }
+
+ void free(void * ptr, size_t size) override {
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
+ ggml_cuda_buffer& b = buffer_pool[i];
+ if (b.ptr == nullptr) {
+ b.ptr = ptr;
+ b.size = size;
+ return;
+ }
+ }
+ GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n");
+ ggml_cuda_set_device(device);
+ CUDA_CHECK(cudaFree(ptr));
+ pool_size -= size;
+ }
+};
+
+// pool with virtual memory
+#if defined(GGML_USE_VMM)
+struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
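+    // Bump allocator over a single reserved virtual address range: physical memory is mapped lazily at the end of
+    // the range as the pool grows, and allocations must be freed in reverse order (see the assert in free()).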
+ static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
+
+ int device;
+ CUdeviceptr pool_addr = 0;
+ size_t pool_used = 0;
+ size_t pool_size = 0;
+ size_t granularity;
+#if defined(GGML_USE_HIP)
+ std::vector<std::pair<CUdeviceptr, size_t>> mappings;
+#endif
+
+ explicit ggml_cuda_pool_vmm(int device) :
+ device(device),
+ granularity(ggml_cuda_info().devices[device].vmm_granularity) {
+ }
+
+ ~ggml_cuda_pool_vmm() {
+ if (pool_addr != 0) {
+#if defined(GGML_USE_HIP)
+ // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285
+ for (std::pair<CUdeviceptr, size_t> & mapping : mappings) {
+ CU_CHECK(cuMemUnmap(mapping.first, mapping.second));
+ }
+#else
+ CU_CHECK(cuMemUnmap(pool_addr, pool_size));
+#endif
+ CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE));
+ }
+ }
+
+ void * alloc(size_t size, size_t * actual_size) override {
+ // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
+ const size_t alignment = 128;
+ size = alignment * ((size + alignment - 1) / alignment);
+
+ size_t avail = pool_size - pool_used;
+
+ if (size > avail) {
+ // round up to the next multiple of the granularity
+ size_t reserve_size = size - avail;
+ reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
+
+ GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
+
+ // allocate more physical memory
+ CUmemAllocationProp prop = {};
+ prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+ prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+ prop.location.id = device;
+ CUmemGenericAllocationHandle handle;
+ CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
+
+ // reserve virtual address space (if not already reserved)
+ if (pool_addr == 0) {
+ CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
+ }
+
+ // map at the end of the pool
+ CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size);
+ CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0));
+#if defined(GGML_USE_HIP)
+ mappings.push_back({start_ptr, reserve_size});
+#endif
+
+ // the memory allocation handle is no longer needed after mapping
+ CU_CHECK(cuMemRelease(handle));
+
+ // set access
+ CUmemAccessDesc access = {};
+ access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+ access.location.id = device;
+ access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+ CU_CHECK(cuMemSetAccess((CUdeviceptr)((char *)(pool_addr) + pool_size), reserve_size, &access, 1));
+
+ // add to the pool
+ pool_size += reserve_size;
+
+ //printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
+ // device, (unsigned long long) (pool_size/1024/1024),
+ // (unsigned long long) (reserve_size/1024/1024));
+ }
+
+ GGML_ASSERT(pool_addr != 0);
+
+ void * ptr = (void *) ((CUdeviceptr)((char *)(pool_addr) + pool_used));
+ *actual_size = size;
+ pool_used += size;
+
+#ifdef DEBUG_CUDA_MALLOC
+ printf("cuda pool[%d]: allocated %llu bytes at %llx\n", device, (unsigned long long) size, ptr);
+#endif
+
+ return ptr;
+ }
+
+ void free(void * ptr, size_t size) override {
+#ifdef DEBUG_CUDA_MALLOC
+ printf("cuda pool[%d]: freed %llu bytes at %llx\n", device, (unsigned long long) size, ptr);
+#endif
+
+ pool_used -= size;
+
+ // all deallocations must be in reverse order of the allocations
+ GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used));
+ }
+};
+#endif // defined(GGML_USE_VMM)
+
+std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device,
+ [[maybe_unused]] int stream_no) {
+#if defined(GGML_USE_VMM)
+ if (ggml_cuda_info().devices[device].vmm) {
+ return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
+ }
+#endif // defined(GGML_USE_VMM)
+ return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
+}
+
+// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
+// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
+
+static std::mutex ggml_cuda_lock;
+static std::condition_variable ggml_cuda_lock_cv;
+static std::atomic<int> ggml_cuda_lock_counter;
+
+ggml_backend_cuda_context::~ggml_backend_cuda_context() {
+ std::unique_lock<std::mutex> lock(ggml_cuda_lock);
+ ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
+
+ if (copy_event != nullptr) {
+ CUDA_CHECK(cudaEventDestroy(copy_event));
+ }
+ for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
+ for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
+ if (streams[i][j] != nullptr) {
+ CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
+ }
+ }
+ if (cublas_handles[i] != nullptr) {
+ CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
+ }
+ }
+}
+
+
+// cuda buffer
+
+struct ggml_backend_cuda_buffer_context {
+ int device;
+ void * dev_ptr = nullptr;
+ std::string name;
+
+ ggml_backend_cuda_buffer_context(int device, void * dev_ptr) :
+ device(device), dev_ptr(dev_ptr),
+ name(GGML_CUDA_NAME + std::to_string(device)) {
+ }
+
+ ~ggml_backend_cuda_buffer_context() {
+ CUDA_CHECK(cudaFree(dev_ptr));
+ }
+};
+
+static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+ delete ctx;
+}
+
+static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
+ return buffer->iface.free_buffer == ggml_backend_cuda_buffer_free_buffer;
+}
+
+static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+ return ctx->dev_ptr;
+}
+
+static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+ if (tensor->view_src != NULL) {
+ assert(tensor->view_src->buffer->buft == buffer->buft);
+ return GGML_STATUS_SUCCESS;
+ }
+
+ if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
+ // initialize padding to 0 to avoid possible NaN values
+ const size_t original_size = ggml_nbytes(tensor);
+ const size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
+
+ if (padded_size > original_size) {
+ ggml_cuda_set_device(ctx->device);
+ CUDA_CHECK(cudaMemset((char *)tensor->data + original_size, 0, padded_size - original_size));
+ }
+ }
+ return GGML_STATUS_SUCCESS;
+}
+
+static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+ ggml_cuda_set_device(ctx->device);
+ CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread));
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
+static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+ ggml_cuda_set_device(ctx->device);
+ CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
+static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+ ggml_cuda_set_device(ctx->device);
+ CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
+static bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
+ if (ggml_backend_buffer_is_cuda(src->buffer)) {
+ ggml_backend_cuda_buffer_context * src_ctx = (ggml_backend_cuda_buffer_context *)src->buffer->context;
+ ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *)dst->buffer->context;
+ if (src_ctx->device == dst_ctx->device) {
+ CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(src), cudaMemcpyDeviceToDevice, cudaStreamPerThread));
+ } else {
+#ifdef GGML_CUDA_NO_PEER_COPY
+ return false;
+#else
+ CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, dst_ctx->device, src->data, src_ctx->device, ggml_nbytes(src), cudaStreamPerThread));
+#endif
+ }
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+ return true;
+ }
+ return false;
+
+ GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
+
+ ggml_cuda_set_device(ctx->device);
+ CUDA_CHECK(cudaMemsetAsync(ctx->dev_ptr, value, buffer->size, cudaStreamPerThread));
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+}
+
+static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
+ /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_cuda_buffer_get_base,
+ /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor,
+ /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor,
+ /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor,
+ /* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor,
+ /* .clear = */ ggml_backend_cuda_buffer_clear,
+ /* .reset = */ NULL,
+};
+
+// cuda buffer type
+struct ggml_backend_cuda_buffer_type_context {
+ int device;
+ std::string name;
+};
+
+static const char * ggml_backend_cuda_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+ ggml_backend_cuda_buffer_type_context * ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
+
+ return ctx->name.c_str();
+}
+
+static bool ggml_backend_buft_is_cuda(ggml_backend_buffer_type_t buft) {
+ return buft->iface.get_name == ggml_backend_cuda_buffer_type_get_name;
+}
+
+static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
+
+ ggml_cuda_set_device(buft_ctx->device);
+
+ void * dev_ptr;
+ cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
+ if (err != cudaSuccess) {
+ // clear the error
+ (void)cudaGetLastError();
+ GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, buft_ctx->device, cudaGetErrorString(err));
+ return nullptr;
+ }
+
+ ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);
+
+ return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
+}
+
+static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ return 128;
+
+ GGML_UNUSED(buft);
+}
+
+static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+ size_t size = ggml_nbytes(tensor);
+ int64_t ne0 = tensor->ne[0];
+
+ if (ggml_is_quantized(tensor->type)) {
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ GGML_ASSERT(tensor->nb[0] == ggml_element_size(tensor));
+ size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+ }
+ }
+
+ return size;
+
+ GGML_UNUSED(buft);
+}
+
+static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
+ /* .get_name = */ ggml_backend_cuda_buffer_type_get_name,
+ /* .alloc_buffer = */ ggml_backend_cuda_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_cuda_buffer_type_get_alignment,
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
+ /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size,
+ /* .is_host = */ NULL,
+};
+
+ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+
+ if (device >= ggml_backend_cuda_get_device_count()) {
+ return nullptr;
+ }
+
+ static ggml_backend_buffer_type ggml_backend_cuda_buffer_types[GGML_CUDA_MAX_DEVICES];
+
+ static bool ggml_backend_cuda_buffer_type_initialized = false;
+
+ if (!ggml_backend_cuda_buffer_type_initialized) {
+ for (int i = 0; i < ggml_backend_cuda_get_device_count(); i++) {
+ ggml_backend_cuda_buffer_types[i] = {
+ /* .iface = */ ggml_backend_cuda_buffer_type_interface,
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), i),
+ /* .context = */ new ggml_backend_cuda_buffer_type_context{i, GGML_CUDA_NAME + std::to_string(i)},
+ };
+ }
+ ggml_backend_cuda_buffer_type_initialized = true;
+ }
+
+ return &ggml_backend_cuda_buffer_types[device];
+}
+
+// cuda split buffer
+
+static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
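+    // the per-device row split is rounded to a multiple of the largest get_mmq_y_host() value among the devices
+    // that actually receive rows, presumably so that each slice stays compatible with the MMQ tile height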
+ int64_t row_rounding = 0;
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
+ continue;
+ }
+
+ const int cc = ggml_cuda_info().devices[id].cc;
+ row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
+ }
+ return row_rounding;
+}
+
+static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {
+ const int64_t nrows = ggml_nrows(tensor);
+ const int64_t rounding = get_row_rounding(tensor_split);
+
+ *row_low = id == 0 ? 0 : nrows*tensor_split[id];
+ *row_low -= *row_low % rounding;
+
+ if (id == ggml_backend_cuda_get_device_count() - 1) {
+ *row_high = nrows;
+ } else {
+ *row_high = nrows*tensor_split[id + 1];
+ *row_high -= *row_high % rounding;
+ }
+}
+
+static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
+ static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+ return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
+}
+
+struct ggml_backend_cuda_split_buffer_type_context {
+ int main_device;
+ std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
+ std::string name;
+};
+
+struct ggml_backend_cuda_split_buffer_context {
+ ~ggml_backend_cuda_split_buffer_context() {
+ for (ggml_tensor_extra_gpu * extra : tensor_extras) {
+ for (int id = 0; id < GGML_CUDA_MAX_DEVICES; ++id) {
+ for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) {
+ if (extra->events[id][is] != nullptr) {
+ CUDA_CHECK(cudaEventDestroy(extra->events[id][is]));
+ }
+ }
+ if (extra->data_device[id] != nullptr) {
+ CUDA_CHECK(cudaFree(extra->data_device[id]));
+ }
+ }
+ delete extra;
+ }
+ }
+
+ std::vector<ggml_tensor_extra_gpu *> tensor_extras;
+};
+
+
+static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
+ delete ctx;
+}
+
+static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {
+    // the pointers are stored in the tensor extras; this is just a dummy address that is never dereferenced
+ return (void *)0x1000;
+
+ GGML_UNUSED(buffer);
+}
+
+static enum ggml_status ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+ GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
+ GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
+
+ ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
+ ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
+
+ const int64_t ne0 = tensor->ne[0];
+
+ ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
+ ctx->tensor_extras.push_back(extra);
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ int64_t row_low, row_high;
+ get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+ int64_t nrows_split = row_high - row_low;
+ if (nrows_split == 0) {
+ continue;
+ }
+
+ size_t size = ggml_nbytes_split(tensor, nrows_split);
+ const size_t original_size = size;
+
+ // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+ }
+
+ // FIXME: do not crash if cudaMalloc fails
+ // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
+ ggml_cuda_set_device(id);
+ char * buf;
+ CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));
+
+ // set padding to 0 to avoid possible NaN values
+ if (size > original_size) {
+ CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size));
+ }
+
+ extra->data_device[id] = buf;
+
+ for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) {
+ CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming));
+ }
+ }
+ tensor->extra = extra;
+ return GGML_STATUS_SUCCESS;
+}
+
+static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ // split tensors must always be set in their entirety at once
+ GGML_ASSERT(offset == 0);
+ GGML_ASSERT(size == ggml_nbytes(tensor));
+ GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
+
+ ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
+
+ const int64_t ne0 = tensor->ne[0];
+ const size_t nb1 = tensor->nb[1];
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ int64_t row_low, row_high;
+ get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+ int64_t nrows_split = row_high - row_low;
+ if (nrows_split == 0) {
+ continue;
+ }
+
+ const size_t offset_split = row_low*nb1;
+ size_t size = ggml_nbytes_split(tensor, nrows_split);
+ const size_t original_size = size;
+
+ // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+ }
+
+ const char * buf_host = (const char *)data + offset_split;
+ CUDA_CHECK(cudaMemcpyAsync(extra->data_device[id], buf_host, original_size, cudaMemcpyHostToDevice, cudaStreamPerThread));
+ }
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+ }
+}
+
+static void ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ // split tensors must always be set in their entirety at once
+ GGML_ASSERT(offset == 0);
+ GGML_ASSERT(size == ggml_nbytes(tensor));
+ GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
+
+ ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
+
+ const int64_t ne0 = tensor->ne[0];
+ const size_t nb1 = tensor->nb[1];
+ ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ int64_t row_low, row_high;
+ get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
+
+ int64_t nrows_split = row_high - row_low;
+ if (nrows_split == 0) {
+ continue;
+ }
+
+ const size_t offset_split = row_low*nb1;
+ size_t size = ggml_nbytes_split(tensor, nrows_split);
+ const size_t original_size = size;
+
+ // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+ }
+
+ char * buf_host = (char *)data + offset_split;
+ CUDA_CHECK(cudaMemcpyAsync(buf_host, extra->data_device[id], original_size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
+ }
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
+ }
+}
+
+static void ggml_backend_cuda_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+ GGML_UNUSED(buffer);
+ GGML_UNUSED(value);
+}
+
+static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
+ /* .free_buffer = */ ggml_backend_cuda_split_buffer_free_buffer,
+ /* .get_base = */ ggml_backend_cuda_split_buffer_get_base,
+ /* .init_tensor = */ ggml_backend_cuda_split_buffer_init_tensor,
+ /* .memset_tensor = */ NULL,
+ /* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor,
+ /* .get_tensor = */ ggml_backend_cuda_split_buffer_get_tensor,
+ /* .cpy_tensor = */ NULL,
+ /* .clear = */ ggml_backend_cuda_split_buffer_clear,
+ /* .reset = */ NULL,
+};
+
+// cuda split buffer type
+
+static const char * ggml_backend_cuda_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+ ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;
+
+ return ctx->name.c_str();
+}
+
+static bool ggml_backend_buft_is_cuda_split(ggml_backend_buffer_type_t buft) {
+ return buft->iface.get_name == ggml_backend_cuda_split_buffer_type_get_name;
+}
+
+static ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
+ // instead, we allocate them for each tensor separately in init_tensor
+ // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
+ // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
+ ggml_backend_cuda_split_buffer_context * ctx = new ggml_backend_cuda_split_buffer_context();
+
+ return ggml_backend_buffer_init(buft, ggml_backend_cuda_split_buffer_interface, ctx, size);
+}
+
+static size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+ return 128;
+
+ GGML_UNUSED(buft);
+}
+
+static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+ ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;
+ GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
+
+ size_t total_size = 0;
+
+ const int64_t ne0 = tensor->ne[0];
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ int64_t row_low, row_high;
+ get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, id);
+
+ int64_t nrows_split = row_high - row_low;
+ if (nrows_split == 0) {
+ continue;
+ }
+
+ total_size += ggml_nbytes_split(tensor, nrows_split);
+
+ // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
+ total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+ }
+ }
+
+ return total_size;
+}
+
+static bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+ return false;
+
+ GGML_UNUSED(buft);
+}
+
+static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface = {
+ /* .get_name = */ ggml_backend_cuda_split_buffer_type_get_name,
+ /* .alloc_buffer = */ ggml_backend_cuda_split_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_cuda_split_buffer_type_get_alignment,
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
+ /* .get_alloc_size = */ ggml_backend_cuda_split_buffer_type_get_alloc_size,
+ /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host,
+};
+
+ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+
+ static std::map<std::pair<int, std::array<float, GGML_CUDA_MAX_DEVICES>>, struct ggml_backend_buffer_type> buft_map;
+
+ std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split_arr = {};
+
+ bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_CUDA_MAX_DEVICES, [](float x) { return x == 0.0f; });
+ if (all_zero) {
+ tensor_split_arr = ggml_cuda_info().default_tensor_split;
+ } else {
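+        // normalize the user-provided per-device proportions into cumulative split points in [0, 1)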
+ float split_sum = 0.0f;
+ for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
+ tensor_split_arr[i] = split_sum;
+ split_sum += tensor_split[i];
+ }
+ for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
+ tensor_split_arr[i] /= split_sum;
+ }
+ }
+
+ auto it = buft_map.find({main_device, tensor_split_arr});
+ if (it != buft_map.end()) {
+ return &it->second;
+ }
+ auto * ctx = new ggml_backend_cuda_split_buffer_type_context{
+ main_device,
+ tensor_split_arr,
+ GGML_CUDA_NAME + std::to_string(main_device) + "_Split",
+ };
+
+ struct ggml_backend_buffer_type buft {
+ /* .iface = */ ggml_backend_cuda_split_buffer_type_interface,
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), main_device),
+ /* .context = */ ctx,
+ };
+
+ auto result = buft_map.emplace(std::make_pair(main_device, tensor_split_arr), buft);
+ return &result.first->second;
+}
+
+// host buffer type
+
+static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
+ return GGML_CUDA_NAME "_Host";
+
+ GGML_UNUSED(buft);
+}
+
+static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
+ return buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;
+}
+
+static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+ CUDA_CHECK(cudaFreeHost(buffer->context));
+}
+
+static void * ggml_cuda_host_malloc(size_t size) {
+ if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
+ return nullptr;
+ }
+
+ void * ptr = nullptr;
+ cudaError_t err = cudaMallocHost((void **) &ptr, size);
+ if (err != cudaSuccess) {
+ // clear the error
+ (void)cudaGetLastError();
+ GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
+ size / 1024.0 / 1024.0, cudaGetErrorString(err));
+ return nullptr;
+ }
+
+ return ptr;
+}
+
+static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+ void * ptr = ggml_cuda_host_malloc(size);
+
+ if (ptr == nullptr) {
+ // fallback to cpu buffer
+ return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+ }
+
+ ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
+ buffer->buft = buft;
+ buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
+
+ return buffer;
+}
+
+ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() {
+ static struct ggml_backend_buffer_type ggml_backend_cuda_buffer_type_host = {
+ /* .iface = */ {
+ /* .get_name = */ ggml_backend_cuda_host_buffer_type_name,
+ /* .alloc_buffer = */ ggml_backend_cuda_host_buffer_type_alloc_buffer,
+ /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
+ /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
+ /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
+ },
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), 0),
+ /* .context = */ nullptr,
+ };
+
+ return &ggml_backend_cuda_buffer_type_host;
+}
+
+//static bool ggml_backend_buffer_is_cuda_host(ggml_backend_buffer_t buffer) {
+// return buffer->buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;
+//}
+
+/// kernels
+
+typedef void (*ggml_cuda_op_mul_mat_t)(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream);
+
+#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE
+#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
+#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE
+
+#define MUL_MAT_SRC1_COL_STRIDE 128
+
+static cudaError_t ggml_cuda_cpy_tensor_2d(
+ void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
+
+ const char * src_ptr = (const char *) src->data;
+ char * dst_ptr = (char *) dst;
+
+ const int64_t ne0 = src->ne[0];
+ const int64_t nb0 = src->nb[0];
+ const int64_t nb1 = src->nb[1];
+ const int64_t nb2 = src->nb[2];
+ const int64_t nb3 = src->nb[3];
+ const enum ggml_type type = src->type;
+ const int64_t ts = ggml_type_size(type);
+ const int64_t bs = ggml_blck_size(type);
+ const int64_t i1_diff = i1_high - i1_low;
+
+ const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
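+ // three cases: fully contiguous data is copied with a single memcpy, rows that are internally
+ // contiguous but strided use one 2D copy, otherwise each row is copied separately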
+ if (nb0 == ts && nb1 == ts*ne0/bs) {
+ return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, cudaMemcpyDeviceToDevice, stream);
+ } else if (nb0 == ts) {
+ return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyDeviceToDevice, stream);
+ } else {
+ for (int64_t i1 = 0; i1 < i1_diff; i1++) {
+ const void * rx = (const void *) ((const char *) x + i1*nb1);
+ void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+ // pretend the row is a matrix with cols=1
+ cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyDeviceToDevice, stream);
+ if (r != cudaSuccess) {
+ return r;
+ }
+ }
+ return cudaSuccess;
+ }
+}
+
+static void ggml_cuda_op_mul_mat_cublas(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+ GGML_ASSERT(src0_dd_i != nullptr);
+ GGML_ASSERT(src1_ddf_i != nullptr);
+ GGML_ASSERT(dst_dd_i != nullptr);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne10 = src1->ne[0];
+
+ const int64_t ne0 = dst->ne[0];
+
+ const int64_t row_diff = row_high - row_low;
+
+ int id = ggml_cuda_get_device();
+
+ // the main device has a larger memory buffer to hold the results from all GPUs
+ // ldc == nrows of the matrix that cuBLAS writes into
+ int64_t ldc = id == ctx.device ? ne0 : row_diff;
+
+ const int cc = ggml_cuda_info().devices[id].cc;
+
+ const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
+ (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
+
+ const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
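+ // three paths below: BF16 GEMM with the result converted back to F32, F16 GEMM on hardware with
+ // fast FP16, or an F32 cublasSgemm fallback that converts/dequantizes the inputs first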
+
+ if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+ ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
+ if (src1->type != GGML_TYPE_BF16) {
+ const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
+ GGML_ASSERT(to_bf16_cuda != nullptr);
+ size_t ne = src1_ncols*ne10;
+ src1_as_bf16.alloc(ne);
+ to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), ne, stream);
+ }
+ const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();
+ const nv_bfloat16 * src0_ptr = (const nv_bfloat16 *)src0_dd_i;
+ ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols);
+
+ const float alpha_f32 = 1.0f;
+ const float beta_f32 = 0.0f;
+
+ CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+ CUBLAS_CHECK(
+ cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+ row_diff, src1_ncols, ne10,
+ &alpha_f32, src0_ptr, CUDA_R_16BF, ne00,
+ src1_ptr, CUDA_R_16BF, ne10,
+ &beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc,
+ CUBLAS_COMPUTE_32F,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
+ to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+ } else if (fast_fp16_hardware_available(cc) && use_fp16) {
+ // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
+ ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
+ if (src0->type != GGML_TYPE_F16) {
+ const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
+ GGML_ASSERT(to_fp16_cuda != nullptr);
+ size_t ne = row_diff*ne00;
+ src0_as_f16.alloc(ne);
+ to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);
+ }
+ const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
+
+ ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
+ if (src1->type != GGML_TYPE_F16) {
+ const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+ GGML_ASSERT(to_fp16_cuda != nullptr);
+ size_t ne = src1_ncols*ne10;
+ src1_as_f16.alloc(ne);
+ to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
+ }
+ const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
+
+ CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+
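+ // on CDNA/RDNA4 compute in F32 and write F32 output directly;
+ // otherwise compute in F16 into a temporary buffer and convert the result to F32 afterwards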
+ if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+ CUBLAS_CHECK(
+ cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+ row_diff, src1_ncols, ne10,
+ &alpha, src0_ptr, CUDA_R_16F, ne00,
+ src1_ptr, CUDA_R_16F, ne10,
+ &beta, dst_dd_i, CUDA_R_32F, ldc,
+ CUBLAS_COMPUTE_32F,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+ } else {
+ ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
+
+ const half alpha_f16 = 1.0f;
+ const half beta_f16 = 0.0f;
+
+ CUBLAS_CHECK(
+ cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+ row_diff, src1_ncols, ne10,
+ &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
+ src1_ptr, CUDA_R_16F, ne10,
+ &beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
+ CUBLAS_COMPUTE_16F,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+ to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+ }
+ } else {
+ ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
+ ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));
+
+ if (src0->type != GGML_TYPE_F32) {
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+ GGML_ASSERT(to_fp32_cuda != nullptr);
+ src0_ddq_as_f32.alloc(row_diff*ne00);
+ to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
+ }
+ if (src1->type != GGML_TYPE_F32) {
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
+ GGML_ASSERT(to_fp32_cuda != nullptr);
+ src1_ddq_as_f32.alloc(src1_ncols*ne10);
+ to_fp32_cuda(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
+ }
+
+ const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
+ const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
+
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+
+ CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+ CUBLAS_CHECK(
+ cublasSgemm(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+ row_diff, src1_ncols, ne10,
+ &alpha, src0_ddf_i, ne00,
+ src1_ddf1_i, ne10,
+ &beta, dst_dd_i, ldc));
+ }
+
+ GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size);
+}
+
+static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
+ static bool peer_access_enabled = false;
+
+ const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
+
+ if (peer_access_enabled == enable_peer_access) {
+ return;
+ }
+
+#ifdef NDEBUG
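+ // peer access is only toggled in release builds (NDEBUG);
+ // synchronize all devices first so no kernels are in flight while access changes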
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ ggml_cuda_set_device(id);
+ CUDA_CHECK(cudaDeviceSynchronize());
+ }
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ ggml_cuda_set_device(id);
+
+ for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) {
+ if (id == id_other) {
+ continue;
+ }
+ if (id != main_device && id_other != main_device) {
+ continue;
+ }
+
+ int can_access_peer;
+ CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
+ if (can_access_peer) {
+ if (enable_peer_access) {
+ cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0);
+ if (err != cudaErrorPeerAccessAlreadyEnabled) {
+ CUDA_CHECK(err);
+ } else {
+ // reset the error
+ (void)cudaGetLastError();
+ }
+ } else {
+ cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
+ if (err != cudaErrorPeerAccessNotEnabled) {
+ CUDA_CHECK(err);
+ } else {
+ // reset the error
+ (void)cudaGetLastError();
+ }
+ }
+ }
+ }
+ }
+
+ ggml_cuda_set_device(main_device);
+#endif // NDEBUG
+
+ peer_access_enabled = enable_peer_access;
+
+ GGML_UNUSED(main_device);
+}
+
+static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
+ void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
+
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+ // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
+ cudaMemcpy3DPeerParms p = {};
+ p.dstDevice = dstDevice;
+ p.dstPtr = make_cudaPitchedPtr(dst, dpitch, dpitch, height);
+ p.srcDevice = srcDevice;
+ p.srcPtr = make_cudaPitchedPtr(src, spitch, spitch, height);
+ p.extent = make_cudaExtent(width, height, 1);
+ return cudaMemcpy3DPeerAsync(&p, stream);
+#else
+ // HIP does not support cudaMemcpy3DPeerAsync or vmm pools
+ GGML_UNUSED(dstDevice);
+ GGML_UNUSED(srcDevice);
+ return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+}
+
+static void ggml_cuda_op_mul_mat(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
+ quantize_cuda_t quantize_src1) {
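+ // generic multi-device mul_mat driver: splits the rows of src0 across GPUs for split buffers,
+ // stages src1/dst in device buffers (with optional q8_1 quantization of src1),
+ // then calls `op` once per device and per chunk of src1 columns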
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const int64_t ne02 = src0->ne[2];
+ const int64_t ne03 = src0->ne[3];
+
+ const int64_t ne10 = src1->ne[0];
+ const int64_t ne11 = src1->ne[1];
+ const int64_t ne12 = src1->ne[2];
+ const int64_t ne13 = src1->ne[3];
+ const int64_t nrows1 = ggml_nrows(src1);
+
+ const int64_t ne0 = dst->ne[0];
+ const int64_t ne1 = dst->ne[1];
+
+ // const int64_t nb10 = src1->nb[0];
+ const int64_t nb11 = src1->nb[1];
+ const int64_t nb12 = src1->nb[2];
+ const int64_t nb13 = src1->nb[3];
+
+ const int64_t nb2 = dst->nb[2];
+ const int64_t nb3 = dst->nb[3];
+
+ ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
+ ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *) dst->buffer->context;
+
+ GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
+
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ const int64_t i02_divisor = ne12 / ne02;
+ const int64_t i03_divisor = ne13 / ne03;
+
+ const size_t src0_ts = ggml_type_size(src0->type);
+ const size_t src0_bs = ggml_blck_size(src0->type);
+ const size_t q8_1_ts = sizeof(block_q8_1);
+ const size_t q8_1_bs = QK8_1;
+
+ const bool src0_is_contiguous = ggml_is_contiguous(src0);
+ const bool src1_is_contiguous = ggml_is_contiguous(src1);
+
+ const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+
+ const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
+ GGML_ASSERT(!(split && ne02 > 1));
+ GGML_ASSERT(!(split && ne03 > 1));
+ GGML_ASSERT(!(split && ne02 < ne12));
+ GGML_ASSERT(!(split && ne03 < ne13));
+
+ ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr;
+
+ std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
+ if (split) {
+ ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
+ tensor_split = buft_ctx->tensor_split;
+ }
+
+ struct dev_data {
+ int cc;
+
+ ggml_cuda_pool_alloc<char> src0_dd_alloc;
+ ggml_cuda_pool_alloc<float> src1_ddf_alloc;
+ ggml_cuda_pool_alloc<char> src1_ddq_alloc;
+ ggml_cuda_pool_alloc<float> dst_dd_alloc;
+
+ char * src0_dd = nullptr;
+ float * src1_ddf = nullptr; // float
+ char * src1_ddq = nullptr; // q8_1
+ float * dst_dd = nullptr;
+
+ int64_t row_low;
+ int64_t row_high;
+ };
+
+ dev_data dev[GGML_CUDA_MAX_DEVICES];
+
+ int used_devices = 0;
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ dev[id].cc = ggml_cuda_info().devices[id].cc;
+
+ // by default, use all rows
+ dev[id].row_low = 0;
+ dev[id].row_high = ne01;
+
+ // for multi GPU, get the row boundaries from tensor split
+ // and round to mul_mat_q tile sizes
+ if (split) {
+ const int64_t rounding = get_row_rounding(tensor_split);
+
+ if (id != 0) {
+ dev[id].row_low = ne01*tensor_split[id];
+ if (dev[id].row_low < ne01) {
+ dev[id].row_low -= dev[id].row_low % rounding;
+ }
+ }
+
+ if (id != ggml_backend_cuda_get_device_count() - 1) {
+ dev[id].row_high = ne01*tensor_split[id + 1];
+ if (dev[id].row_high < ne01) {
+ dev[id].row_high -= dev[id].row_high % rounding;
+ }
+ }
+ }
+ }
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
+ continue;
+ }
+
+ used_devices++;
+
+ const bool src1_on_device = id == src1_ctx->device;
+ const bool dst_on_device = id == dst_ctx->device;
+
+ ggml_cuda_set_device(id);
+ cudaStream_t stream = ctx.stream(id, 0);
+
+ if (src0_is_contiguous) {
+ dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
+ } else {
+ // If src0 is not contiguous it will be copied to a temporary buffer.
+ // This buffer needs to be cleared entirely because multiple regions will function as padding.
+ const size_t nbytes_data = ggml_nbytes(src0);
+ const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
+ dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
+ CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
+ }
+
+ // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
+ if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
+ GGML_ASSERT(ggml_is_contiguously_allocated(src0));
+ GGML_ASSERT(!src0->view_src);
+ const size_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
+ const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
+ CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
+ }
+
+ if (src1_on_device && src1_is_contiguous) {
+ dev[id].src1_ddf = (float *) src1->data;
+ } else {
+ dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1));
+ }
+
+ if (quantize_src1) {
+ size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
+ if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+ src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq);
+ }
+ dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
+
+ if (src1_on_device && src1_is_contiguous) {
+ quantize_src1(
+ dev[id].src1_ddf, nullptr, dev[id].src1_ddq, src0->type, ne10,
+ nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float),
+ src1_padded_col_size, ne11, ne12, ne13, stream);
+ CUDA_CHECK(cudaGetLastError());
+ }
+ }
+
+ if (dst_on_device) {
+ dev[id].dst_dd = (float *) dst->data;
+ } else {
+ const size_t size_dst_ddf = split ? (dev[id].row_high - dev[id].row_low)*ne1 : ggml_nelements(dst);
+ dev[id].dst_dd = dev[id].dst_dd_alloc.alloc(ctx.pool(id), size_dst_ddf);
+ }
+ }
+
+ // if multiple devices are used they need to wait for the main device
+ // here an event is recorded that signals that the main device has finished calculating the input data
+ if (split && used_devices > 1) {
+ ggml_cuda_set_device(ctx.device);
+ CUDA_CHECK(cudaEventRecord(src0_extra->events[ctx.device][0], ctx.stream()));
+ }
+
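+ // when several devices participate, src1 is processed in chunks of MUL_MAT_SRC1_COL_STRIDE columns,
+ // each chunk on its own stream so copies and compute for different chunks can overlap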
+ const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
+ for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
+ const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_CUDA_MAX_STREAMS : 0;
+ const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
+
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
+ continue;
+ }
+
+ const bool src1_on_device = id == src1_ctx->device;
+ const bool dst_on_device = id == dst_ctx->device;
+ const int64_t row_diff = dev[id].row_high - dev[id].row_low;
+
+ ggml_cuda_set_device(id);
+ cudaStream_t stream = ctx.stream(id, is);
+
+ // wait for main GPU data if necessary
+ if (split && (id != ctx.device || is != 0)) {
+ CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[ctx.device][0], 0));
+ }
+
+ for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
+ const int64_t i03 = i0 / ne12;
+ const int64_t i02 = i0 % ne12;
+
+ size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs;
+ if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+ src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq);
+ } else {
+ src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs;
+ }
+
+ // for split tensors the device buffer only contains the rows assigned to this device
+ const size_t nbytes_src0_matrix = ne01*ne00*src0_ts / src0_bs;
+ char * src0_dd_i = dev[id].src0_dd + ((i03/i03_divisor)*ne02 + (i02/i02_divisor)) * nbytes_src0_matrix;
+ float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
+ char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset;
+ float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
+
+ // the main device memory buffer can be on VRAM scratch, with space for all partial results
+ // in that case an offset on dst_ddf_i is needed
+ if (id == ctx.device) {
+ dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split
+ }
+
+ // copy src0, src1 to device if necessary
+ if (src1_is_contiguous) {
+ if (id != ctx.device) {
+ if (quantize_src1) {
+ char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
+ if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+ const size_t pitch = ne11*sizeof(block_q8_1_mmq);
+ const size_t width = src1_ncols*sizeof(block_q8_1_mmq);
+ const size_t height = src1_padded_col_size/(4*QK8_1);
+ CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream));
+ } else {
+ CUDA_CHECK(cudaMemcpyPeerAsync(
+ src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
+ }
+ } else {
+ float * src1_ddf_i_source = (float *) src1->data;
+ src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
+ CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddf_i, id, src1_ddf_i_source, ctx.device,
+ src1_ncols*ne10*sizeof(float), stream));
+ }
+ }
+ } else if (src1_on_device && !src1_is_contiguous) {
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
+ src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
+ } else {
+ GGML_ABORT("fatal error");
+ }
+
+ if (quantize_src1 && !src1_is_contiguous) {
+ quantize_src1(
+ src1_ddf_i, nullptr, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
+ src1_padded_col_size, src1_ncols, 1, 1, stream);
+ CUDA_CHECK(cudaGetLastError());
+ }
+
+ if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) {
+ CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
+ src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
+ }
+
+ // do the computation
+ op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
+ dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, stream);
+ CUDA_CHECK(cudaGetLastError());
+
+ // copy dst to host or other device if necessary
+ if (!dst_on_device) {
+ void * dst_off_device = dst->data;
+ if (split) {
+ // src0 = weight matrix is saved as a transposed matrix for better memory layout.
+ // dst is NOT transposed.
+ // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
+ // Instead they need to be copied to the correct slice in ne0 = dst row index.
+ // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
+ float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+ GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+ dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
+ CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(
+ dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream));
+ } else {
+ float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+ GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+ dhf_dst_i += src1_col_0*ne0;
+ CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), cudaMemcpyDeviceToDevice, stream));
+ }
+ }
+
+ // add event for the main device to wait on until other device is done
+ if (split && (id != ctx.device || is != 0)) {
+ CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream));
+ }
+ }
+ }
+ }
+
+ // main device waits for all other devices to be finished
+ if (split && ggml_backend_cuda_get_device_count() > 1) {
+ int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
+ is_max = is_max <= GGML_CUDA_MAX_STREAMS ? is_max : GGML_CUDA_MAX_STREAMS;
+
+ ggml_cuda_set_device(ctx.device);
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ if (dev[id].row_low == dev[id].row_high) {
+ continue;
+ }
+ for (int64_t is = 0; is < is_max; ++is) {
+ CUDA_CHECK(cudaStreamWaitEvent(ctx.stream(), src0_extra->events[id][is], 0));
+ }
+ }
+ }
+}
+
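+// fills the per-matrix pointer arrays for cublasGemmBatchedEx: one thread per (i12, i13) batch index,
+// using the broadcast ratios r2/r3 to select the corresponding src0 matrix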
+static __global__ void k_compute_batched_ptrs(
+ const void * src0_as_f16, const void * src1_as_f16, char * dst,
+ const void ** ptrs_src, void ** ptrs_dst,
+ int64_t ne12, int64_t ne13,
+ int64_t ne23,
+ size_t nb02, size_t nb03,
+ size_t nb12, size_t nb13,
+ size_t nbd2, size_t nbd3,
+ int64_t r2, int64_t r3) {
+ const int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
+ const int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
+
+ if (i13 >= ne13 || i12 >= ne12) {
+ return;
+ }
+
+ const int64_t i03 = i13 / r3;
+ const int64_t i02 = i12 / r2;
+
+ ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
+ ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
+ ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst + i12*nbd2 + i13*nbd3;
+}
+
+// Type traits for mapping ggml types to CUDA/cuBLAS types
+template<ggml_type T>
+struct batched_mul_mat_traits;
+
+template<>
+struct batched_mul_mat_traits<GGML_TYPE_F32> {
+ using cuda_type = float;
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
+ static inline const cudaDataType_t data_type = CUDA_R_32F;
+ static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
+ static inline const float alpha = 1.0f;
+ static inline const float beta = 0.0f;
+ static inline const void* get_alpha() { static const float val = alpha; return &val; }
+ static inline const void* get_beta() { static const float val = beta; return &val; }
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
+};
+
+template<>
+struct batched_mul_mat_traits<GGML_TYPE_BF16> {
+ using cuda_type = nv_bfloat16;
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
+ static inline const cudaDataType_t data_type = CUDA_R_16BF;
+ static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
+ static inline const float alpha = 1.0f;
+ static inline const float beta = 0.0f;
+ static inline const void* get_alpha() { static const float val = alpha; return &val; }
+ static inline const void* get_beta() { static const float val = beta; return &val; }
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
+};
+
+template<>
+struct batched_mul_mat_traits<GGML_TYPE_F16> {
+ using cuda_type = half;
+ static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
+ static inline const cudaDataType_t data_type = CUDA_R_16F;
+ static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
+ static inline const half alpha = 1.0;
+ static inline const half beta = 0.0;
+ static inline const void* get_alpha() { static const half val = alpha; return &val; }
+ static inline const void* get_beta() { static const half val = beta; return &val; }
+ static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
+};
+
+template<ggml_type src0_type>
+static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ using traits = batched_mul_mat_traits<src0_type>;
+ using cuda_t = typename traits::cuda_type;
+
+ GGML_ASSERT(!ggml_is_transposed(src0));
+ GGML_ASSERT(!ggml_is_transposed(src1));
+ GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
+ GGML_ASSERT(src0->type == src0_type);
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
+ // As long as dst is contiguous this does not matter though.
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int64_t ne_dst = ggml_nelements(dst);
+ cudaStream_t main_stream = ctx.stream();
+ CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
+
+ float * dst_ddf = (float *) dst->data;
+ const size_t ts_src1 = ggml_type_size(src1->type);
+ GGML_ASSERT(nb10 == ts_src1);
+ int64_t s11 = nb11 / ts_src1;
+ int64_t s12 = nb12 / ts_src1;
+ int64_t s13 = nb13 / ts_src1;
+
+ const cuda_t * src0_ptr = nullptr;
+ const cuda_t * src1_ptr = nullptr;
+
+ ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
+ ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
+
+ bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
+ bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
+
+ // Handle src0
+ src0_ptr = (const cuda_t *) src0->data;
+
+ // Handle src1 - convert if necessary
+ if (src1->type == src0_type) {
+ src1_ptr = (const cuda_t *) src1->data;
+ } else {
+ // Convert src1 to target type using traits conversion functions
+ const int64_t ne_src1 = ggml_nelements(src1);
+ src1_alloc.alloc(ne_src1);
+
+ const auto convert_func = traits::get_nc_converter(src1->type);
+ GGML_ASSERT(convert_func != nullptr);
+ convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+ src1_ptr = src1_alloc.get();
+ s11 = ne10;
+ s12 = ne11*s11;
+ s13 = ne12*s12;
+
+ is_src1_cont_2 = true;
+ }
+
+ // Setup destination buffer
+ ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
+ char * dst_t;
+ size_t nbd2 = dst->nb[2];
+ size_t nbd3 = dst->nb[3];
+
+ cublasComputeType_t cu_compute_type = traits::compute_type;
+ cudaDataType_t cu_data_type = traits::data_type;
+ cudaDataType_t cu_data_type_a = traits::data_type;
+ cudaDataType_t cu_data_type_b = traits::data_type;
+ const void * alpha = traits::get_alpha();
+ const void * beta = traits::get_beta();
+ const float alpha_f32 = 1.0f;
+ const float beta_f32 = 0.0f;
+
+ if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+ if constexpr (src0_type == GGML_TYPE_F32) {
+ dst_t = (char *) dst_ddf; // Direct F32 output
+ } else {
+ dst_t = (char *) dst_temp.alloc(ne_dst);
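+ // dst->nb[] is in bytes of the final F32 tensor; the temporary buffer holds smaller cuda_t
+ // elements, so scale the strides down by the size ratio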
+ nbd2 /= sizeof(float) / sizeof(cuda_t);
+ nbd3 /= sizeof(float) / sizeof(cuda_t);
+ }
+ } else {
+ dst_t = (char *) dst_ddf;
+ cu_compute_type = CUBLAS_COMPUTE_32F;
+ cu_data_type = CUDA_R_32F;
+ alpha = &alpha_f32;
+ beta = &beta_f32;
+ }
+
+ int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
+ if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+ cu_compute_type = CUBLAS_COMPUTE_32F;
+ alpha = &alpha_f32;
+ beta = &beta_f32;
+ }
+
+ GGML_ASSERT(ne12 % ne02 == 0);
+ GGML_ASSERT(ne13 % ne03 == 0);
+
+ // broadcast factors
+ const int64_t r2 = ne12/ne02;
+ const int64_t r3 = ne13/ne03;
+
+ if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+ // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
+ const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+ const int64_t smb = ne12 == 1 ? s13 : s12;
+
+ // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+ // use cublasGemmStridedBatchedEx
+ CUBLAS_CHECK(
+ cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+ ne01, ne11, ne10,
+ alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA
+ src1_ptr, cu_data_type_b, s11, smb, // strideB
+ beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
+ ne12*ne13,
+ cu_compute_type,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+ } else {
+ // use cublasGemmBatchedEx
+ const int64_t ne23 = ne12*ne13;
+
+ ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
+ ggml_cuda_pool_alloc< void *> ptrs_dst(ctx.pool(), 1*ne23);
+
+ size_t src1_stride_size = sizeof(cuda_t);
+
+ const int threads_x = 16;
+ const int threads_y = 16;
+ dim3 block_dims(threads_x, threads_y);
+
+ dim3 grid_dims(
+ (ne13 + threads_x - 1) / threads_x,
+ (ne12 + threads_y - 1) / threads_y
+ );
+ k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(
+ src0_ptr, src1_ptr, dst_t,
+ ptrs_src.get(), ptrs_dst.get(),
+ ne12, ne13,
+ ne23,
+ nb02, nb03,
+ (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
+ (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
+ nbd2, nbd3,
+ r2, r3);
+
+ CUDA_CHECK(cudaGetLastError());
+
+ CUBLAS_CHECK(
+ cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+ ne01, ne11, ne10,
+ alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
+ (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
+ beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
+ ne23,
+ cu_compute_type,
+ CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+ }
+
+ // Convert output back to F32 if needed
+ if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
+ const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
+ to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
+ }
+}
+
+static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
+ break;
+ case GGML_TYPE_BF16:
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
+ break;
+ case GGML_TYPE_F16:
+ ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
+ break;
+ default:
+ GGML_ABORT("Unsupported type");
+ }
+}
+
+static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up,
+ const ggml_tensor * ffn_gate,
+ const ggml_tensor * glu,
+ const ggml_tensor * ffn_up_bias = nullptr,
+ const ggml_tensor * ffn_gate_bias = nullptr) {
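+ // checks whether an ffn_up/ffn_gate mul_mat (or mul_mat_id) pair feeding a GLU, optionally through
+ // bias adds, can be fused into a single gate+up kernel launch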
+ const bool has_bias = ffn_up_bias != nullptr || ffn_gate_bias != nullptr;
+
+ if (has_bias && (!ffn_up_bias || !ffn_gate_bias)) {
+ return false;
+ }
+
+ const bool is_mul_mat = ffn_up->op == GGML_OP_MUL_MAT && ffn_gate->op == GGML_OP_MUL_MAT && glu->op == GGML_OP_GLU;
+ const bool is_mul_mat_id = ffn_up->op == GGML_OP_MUL_MAT_ID && ffn_gate->op == GGML_OP_MUL_MAT_ID && glu->op == GGML_OP_GLU;
+
+ GGML_ASSERT(ffn_up && ffn_gate && glu);
+
+ if (!is_mul_mat && !is_mul_mat_id) {
+ return false;
+ }
+
+ const ggml_op expected_bias_op = is_mul_mat ? GGML_OP_ADD : GGML_OP_ADD_ID;
+
+ if (has_bias) {
+ if (ffn_up_bias->op != expected_bias_op || ffn_gate_bias->op != expected_bias_op) {
+ return false;
+ }
+
+ if (glu->src[0] != ffn_gate_bias || glu->src[1] != ffn_up_bias) {
+ return false;
+ }
+
+ if (expected_bias_op == GGML_OP_ADD) {
+ const bool up_has_mul = ffn_up_bias->src[0] == ffn_up || ffn_up_bias->src[1] == ffn_up;
+ const bool gate_has_mul = ffn_gate_bias->src[0] == ffn_gate || ffn_gate_bias->src[1] == ffn_gate;
+ if (!up_has_mul || !gate_has_mul) {
+ return false;
+ }
+ } else { // GGML_OP_ADD_ID
+ if (ffn_up_bias->src[0] != ffn_up || ffn_gate_bias->src[0] != ffn_gate) {
+ return false;
+ }
+ if (ffn_up_bias->src[2] != ffn_up->src[2] || ffn_gate_bias->src[2] != ffn_gate->src[2]) {
+ return false;
+ }
+ }
+ } else {
+ if (glu->src[0] != ffn_gate && glu->src[1] != ffn_up) {
+ return false;
+ }
+ }
+
+ if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) ||
+ !ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) {
+ return false;
+ }
+
+ if (ffn_up->src[1] != ffn_gate->src[1]) {
+ return false;
+ }
+
+ if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) {
+ return false;
+ }
+
+ static constexpr std::array<ggml_glu_op, 3> valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU_OAI };
+
+ if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) {
+ return false;
+ }
+
+ if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {
+ return false;
+ }
+
+ const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) ||
+ ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft);
+
+ //TODO: add support for fusion for split buffers
+ if (split) {
+ return false;
+ }
+
+ return true;
+}
+
+static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
+ ggml_tensor * src0 = tensor->src[0];
+ ggml_tensor * src1 = tensor->src[1];
+ const ggml_tensor * dst = tensor;
+
+ const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;
+
+ bool use_mul_mat_vec_f =
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) &&
+ src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);
+
+ const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
+ ggml_backend_buft_is_cuda_split(src1->buffer->buft);
+
+ //TODO: add support for fusion for split buffers
+ if (split) {
+ return false;
+ }
+
+ //we only support fusion for ncols_dst = 1
+ if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
+ return false;
+ }
+
+ if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
+ return false;
+ }
+
+ return use_mul_mat_vec_f;
+}
+
+static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
+ ggml_tensor * src0 = tensor->src[0];
+ ggml_tensor * src1 = tensor->src[1];
+ const ggml_tensor * dst = tensor;
+
+ const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&
+ ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&
+ src0->view_src;
+
+ bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 &&
+ dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+
+ // fusion is not universally faster on Pascal
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ if (cc <= GGML_CUDA_CC_PASCAL) {
+ return false;
+ }
+ //we only support fusion for ncols_dst = 1
+ if (tensor->op == GGML_OP_MUL_MAT && dst->ne[1] != 1) {
+ return false;
+ }
+
+ if (tensor->op == GGML_OP_MUL_MAT_ID && dst->ne[2] != 1) {
+ return false;
+ }
+
+ const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
+ ggml_backend_buft_is_cuda_split(src1->buffer->buft);
+
+ //TODO: add support for fusion for split buffers
+ if (split) {
+ return false;
+ }
+
+ return use_mul_mat_vec_q;
+}
+
+static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
+
+ // If src0 is a temporary compute buffer it may have some padding that needs to be cleared for mul_mat_vec_q or mul_mat_q.
+ // But if src0 is also a view of another tensor then this cannot be done safely because it may overwrite valid tensor data.
+ // Therefore, in such cases use cuBLAS.
+ const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
+ && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
+
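+ // initial kernel candidates from tensor types and batch size; the per-device checks below further restrict them by compute capability and shape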
+ bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+ bool use_mul_mat_f = !ggml_is_quantized(src0->type)
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+ bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+ && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+ bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear
+ && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+ bool any_gpus_with_slow_fp16 = false;
+
+ if (split) {
+ ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
+ auto & tensor_split = buft_ctx->tensor_split;
+ for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+ // skip devices that are not going to do any work:
+ if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
+ continue;
+ }
+
+ const int cc = ggml_cuda_info().devices[id].cc;
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
+ use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
+ any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
+ }
+ } else {
+ const int cc = ggml_cuda_info().devices[ctx.device].cc;
+ const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
+ use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0);
+ use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
+ use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]);
+ any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc);
+ }
+
+ // debug helpers
+ //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
+ //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
+ //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
+ //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
+ //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
+ //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+
+ //TODO update for generic tensor parallelism
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ bool use_batched_cublas_f16 = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
+ bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
+ bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32;
+
+ if (!split && use_mul_mat_vec_f) {
+ // the custom F16 vector kernel can be used over batched cuBLAS GEMM
+ // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
+ ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
+ } else if (!split && use_mul_mat_f) {
+ ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
+ } else if (!split && use_mul_mat_vec_q) {
+ ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
+ } else if (!split && use_mul_mat_q) {
+ ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
+ } else if (!split && (use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
+ && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+ // general KQ + KQV multi-batch without FlashAttention
+ ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
+ } else if (use_mul_mat_vec_f) {
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);
+ } else if (use_mul_mat_vec_q) {
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
+ } else if (use_mul_mat_q) {
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
+ } else {
+ ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
+ }
+}
+
+static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * ids = dst->src[2];
+
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+ if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
+ if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
+ if (ggml_is_quantized(src0->type)) {
+ if (ne2 <= 4) {
+ ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
+ return;
+ }
+ } else {
+ if (GGML_CUDA_CC_IS_AMD(cc)) {
+ ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
+ return;
+ }
+ }
+ }
+
+ if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {
+ ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
+ return;
+ }
+
+ if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src0->nb, src1->ne[2], /*mul_mat_id=*/true)) {
+ ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
+ return;
+ }
+ }
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(nb12 % nb11 == 0);
+ GGML_ASSERT(nb2 % nb1 == 0);
+
+ const ggml_type type_src1_sorted = (src0->type == GGML_TYPE_F16 && !fast_fp16_hardware_available(cc))
+ || ggml_is_quantized(src0->type) ? GGML_TYPE_F32 : src0->type;
+ const ggml_type type_dst_sorted = GGML_TYPE_F32;
+ const size_t ts_src1_sorted = ggml_type_size(type_src1_sorted);
+ const size_t ts_dst_sorted = ggml_type_size(type_dst_sorted);
+
+ const int64_t n_expert_used = ids->ne[0];
+ const int64_t ne_get_rows = ne12 * n_expert_used;
+
+ std::vector<int32_t> ids_to_sorted_host;
+ ids_to_sorted_host.reserve(2*ne_get_rows);
+ std::vector<int32_t> ids_from_sorted_host(ne_get_rows);
+
+ ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool(), 2*ne_get_rows);
+
+ std::vector<int32_t> tokens_per_expert(ne02);
+
+ ggml_cuda_pool_alloc<char> src1_sorted(ctx.pool(), ne12*n_expert_used*ne10*ts_src1_sorted);
+ ggml_cuda_pool_alloc<char> dst_sorted(ctx.pool(), ne2 *n_expert_used* ne0*ts_dst_sorted);
+
+ std::vector<char> ids_host(ggml_nbytes(ids));
+ CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
+ CUDA_CHECK(cudaStreamSynchronize(stream));
+
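+ // group token rows by expert: ids_to_sorted gathers src1 rows into expert-major order,
+ // ids_from_sorted is the inverse mapping used to scatter the results back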
+ for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
+ for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
+ for (int64_t iex = 0; iex < n_expert_used; ++iex) {
+ const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
+ assert(expert_to_use >= 0 && expert_to_use < ne02);
+ if (expert_to_use == i02) {
+ ids_from_sorted_host[i12*n_expert_used + iex] = ids_to_sorted_host.size();
+ ids_to_sorted_host.push_back(i12*ne11 + iex % ne11);
+ tokens_per_expert[i02]++;
+ break;
+ }
+ }
+ }
+ }
+ GGML_ASSERT(ids_to_sorted_host.size() == size_t(ne_get_rows));
+
+ ids_to_sorted_host.insert(ids_to_sorted_host.end(), ids_from_sorted_host.begin(), ids_from_sorted_host.end());
+
+ CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_to_sorted_host.data(), 2*ne_get_rows*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
+ CUDA_CHECK(cudaStreamSynchronize(stream));
+
+ const int32_t * ids_to_sorted = ids_buf_dev.ptr + 0*ne_get_rows;
+ const int32_t * ids_from_sorted = ids_buf_dev.ptr + 1*ne_get_rows;
+
+ get_rows_cuda(src1->data, src1->type, ids_to_sorted, src1_sorted.ptr, type_src1_sorted,
+ ne10, nb11, nb12, nb13,
+ ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t),
+ ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, stream);
+ CUDA_CHECK(cudaGetLastError());
+
+ char * src1_data_cur = (char *) src1_sorted.ptr;
+ char * dst_data_cur = (char *) dst_sorted.ptr;
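+ // run one dense mul_mat per expert over its contiguous slice of sorted tokens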
+ for (int64_t i02 = 0; i02 < ne02; ++i02) {
+ if (tokens_per_expert[i02] == 0) {
+ continue;
+ }
+
+ ggml_tensor src0_slice = *src0;
+ src0_slice.ne[2] = 1;
+ src0_slice.nb[3] = src0_slice.nb[2];
+ src0_slice.op = GGML_OP_VIEW;
+ src0_slice.view_src = dst->src[0]; // non-const pointer to src0
+ src0_slice.data = (char *) src0->data + i02*nb02;
+
+ ggml_tensor src1_slice;
+ memset(&src1_slice, 0, sizeof(src1_slice));
+ src1_slice.buffer = src1->buffer;
+ src1_slice.type = type_src1_sorted;
+ src1_slice.ne[0] = ne10;
+ src1_slice.ne[1] = tokens_per_expert[i02];
+ src1_slice.ne[2] = 1;
+ src1_slice.ne[3] = 1;
+ src1_slice.nb[0] = ts_src1_sorted;
+ src1_slice.nb[1] = src1_slice.ne[0] * src1_slice.nb[0];
+ src1_slice.nb[2] = src1_slice.ne[1] * src1_slice.nb[1];
+ src1_slice.nb[3] = src1_slice.ne[2] * src1_slice.nb[2];
+ src1_slice.data = src1_data_cur;
+
+ ggml_tensor dst_slice;
+ memset(&dst_slice, 0, sizeof(dst_slice));
+ dst_slice.buffer = dst->buffer;
+ dst_slice.type = type_dst_sorted;
+ dst_slice.ne[0] = ne0;
+ dst_slice.ne[1] = tokens_per_expert[i02];
+ dst_slice.ne[2] = 1;
+ dst_slice.ne[3] = 1;
+ dst_slice.nb[0] = ts_dst_sorted;
+ dst_slice.nb[1] = dst_slice.ne[0] * dst_slice.nb[0];
+ dst_slice.nb[2] = dst_slice.ne[1] * dst_slice.nb[1];
+ dst_slice.nb[3] = dst_slice.ne[2] * dst_slice.nb[2];
+ dst_slice.data = dst_data_cur;
+
+ ggml_cuda_mul_mat(ctx, &src0_slice, &src1_slice, &dst_slice);
+ CUDA_CHECK(cudaGetLastError());
+
+ src1_data_cur += src1_slice.nb[2];
+ dst_data_cur += dst_slice.nb[2];
+ }
+
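+ // scatter the per-expert results back into dst using the inverse permutation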
+ get_rows_cuda(dst_sorted.ptr, type_dst_sorted, ids_from_sorted, dst->data, dst->type,
+ ne0, ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted,
+ ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t),
+ nb1, nb2, nb3, stream);
+}
+
+static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
+ // why is this here instead of mul_mat?
+ if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) {
+ ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
+ }
+
+ switch (dst->op) {
+ case GGML_OP_ARGMAX:
+ ggml_cuda_argmax(ctx, dst);
+ break;
+ case GGML_OP_COUNT_EQUAL:
+ ggml_cuda_count_equal(ctx, dst);
+ break;
+ case GGML_OP_REPEAT:
+ ggml_cuda_op_repeat(ctx, dst);
+ break;
+ case GGML_OP_REPEAT_BACK:
+ ggml_cuda_op_repeat_back(ctx, dst);
+ break;
+ case GGML_OP_GET_ROWS:
+ ggml_cuda_op_get_rows(ctx, dst);
+ break;
+ case GGML_OP_GET_ROWS_BACK:
+ ggml_cuda_op_get_rows_back(ctx, dst);
+ break;
+ case GGML_OP_SET_ROWS:
+ ggml_cuda_op_set_rows(ctx, dst);
+ break;
+ case GGML_OP_SET:
+ ggml_cuda_op_set(ctx, dst);
+ break;
+ case GGML_OP_DUP:
+ ggml_cuda_dup(ctx, dst);
+ break;
+ case GGML_OP_CPY:
+ ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);
+ break;
+ case GGML_OP_CONT:
+ ggml_cuda_dup(ctx, dst);
+ break;
+ case GGML_OP_ADD:
+ case GGML_OP_ADD1: // TODO: more efficient implementation
+ ggml_cuda_op_add(ctx, dst);
+ break;
+ case GGML_OP_ADD_ID:
+ ggml_cuda_op_add_id(ctx, dst);
+ break;
+ case GGML_OP_SUB:
+ ggml_cuda_op_sub(ctx, dst);
+ break;
+ case GGML_OP_ACC:
+ ggml_cuda_op_acc(ctx, dst);
+ break;
+ case GGML_OP_MUL:
+ ggml_cuda_op_mul(ctx, dst);
+ break;
+ case GGML_OP_DIV:
+ ggml_cuda_op_div(ctx, dst);
+ break;
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(dst)) {
+ case GGML_UNARY_OP_ABS:
+ ggml_cuda_op_abs(ctx, dst);
+ break;
+ case GGML_UNARY_OP_SGN:
+ ggml_cuda_op_sgn(ctx, dst);
+ break;
+ case GGML_UNARY_OP_NEG:
+ ggml_cuda_op_neg(ctx, dst);
+ break;
+ case GGML_UNARY_OP_STEP:
+ ggml_cuda_op_step(ctx, dst);
+ break;
+ case GGML_UNARY_OP_GELU:
+ ggml_cuda_op_gelu(ctx, dst);
+ break;
+ case GGML_UNARY_OP_SILU:
+ ggml_cuda_op_silu(ctx, dst);
+ break;
+ case GGML_UNARY_OP_GELU_ERF:
+ ggml_cuda_op_gelu_erf(ctx, dst);
+ break;
+ case GGML_UNARY_OP_GELU_QUICK:
+ ggml_cuda_op_gelu_quick(ctx, dst);
+ break;
+ case GGML_UNARY_OP_TANH:
+ ggml_cuda_op_tanh(ctx, dst);
+ break;
+ case GGML_UNARY_OP_RELU:
+ ggml_cuda_op_relu(ctx, dst);
+ break;
+ case GGML_UNARY_OP_SIGMOID:
+ ggml_cuda_op_sigmoid(ctx, dst);
+ break;
+ case GGML_UNARY_OP_HARDSIGMOID:
+ ggml_cuda_op_hardsigmoid(ctx, dst);
+ break;
+ case GGML_UNARY_OP_HARDSWISH:
+ ggml_cuda_op_hardswish(ctx, dst);
+ break;
+ case GGML_UNARY_OP_EXP:
+ ggml_cuda_op_exp(ctx, dst);
+ break;
+ case GGML_UNARY_OP_ELU:
+ ggml_cuda_op_elu(ctx, dst);
+ break;
+ case GGML_UNARY_OP_XIELU:
+ ggml_cuda_op_xielu(ctx, dst);
+ break;
+ case GGML_UNARY_OP_FLOOR:
+ ggml_cuda_op_floor(ctx, dst);
+ break;
+ case GGML_UNARY_OP_CEIL:
+ ggml_cuda_op_ceil(ctx, dst);
+ break;
+ case GGML_UNARY_OP_ROUND:
+ ggml_cuda_op_round(ctx, dst);
+ break;
+ case GGML_UNARY_OP_TRUNC:
+ ggml_cuda_op_trunc(ctx, dst);
+ break;
+ case GGML_UNARY_OP_EXPM1:
+ ggml_cuda_op_expm1(ctx, dst);
+ break;
+ case GGML_UNARY_OP_SOFTPLUS:
+ ggml_cuda_op_softplus(ctx, dst);
+ break;
+ default:
+ return false;
+ }
+ break;
+ case GGML_OP_GLU:
+ switch (ggml_get_glu_op(dst)) {
+ case GGML_GLU_OP_REGLU:
+ ggml_cuda_op_reglu(ctx, dst);
+ break;
+ case GGML_GLU_OP_GEGLU:
+ ggml_cuda_op_geglu(ctx, dst);
+ break;
+ case GGML_GLU_OP_SWIGLU:
+ ggml_cuda_op_swiglu(ctx, dst);
+ break;
+ case GGML_GLU_OP_SWIGLU_OAI:
+ ggml_cuda_op_swiglu_oai(ctx, dst);
+ break;
+ case GGML_GLU_OP_GEGLU_ERF:
+ ggml_cuda_op_geglu_erf(ctx, dst);
+ break;
+ case GGML_GLU_OP_GEGLU_QUICK:
+ ggml_cuda_op_geglu_quick(ctx, dst);
+ break;
+ default:
+ return false;
+ }
+ break;
+ case GGML_OP_NORM:
+ ggml_cuda_op_norm(ctx, dst);
+ break;
+ case GGML_OP_GROUP_NORM:
+ ggml_cuda_op_group_norm(ctx, dst);
+ break;
+ case GGML_OP_L2_NORM:
+ ggml_cuda_op_l2_norm(ctx, dst);
+ break;
+ case GGML_OP_CONCAT:
+ ggml_cuda_op_concat(ctx, dst);
+ break;
+ case GGML_OP_UPSCALE:
+ ggml_cuda_op_upscale(ctx, dst);
+ break;
+ case GGML_OP_PAD:
+ ggml_cuda_op_pad(ctx, dst);
+ break;
+ case GGML_OP_PAD_REFLECT_1D:
+ ggml_cuda_op_pad_reflect_1d(ctx, dst);
+ break;
+ case GGML_OP_ARANGE:
+ ggml_cuda_op_arange(ctx, dst);
+ break;
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ ggml_cuda_op_timestep_embedding(ctx, dst);
+ break;
+ case GGML_OP_LEAKY_RELU:
+ ggml_cuda_op_leaky_relu(ctx, dst);
+ break;
+ case GGML_OP_SILU_BACK:
+ ggml_cuda_op_silu_back(ctx, dst);
+ break;
+ case GGML_OP_RMS_NORM:
+ ggml_cuda_op_rms_norm(ctx, dst);
+ break;
+ case GGML_OP_RMS_NORM_BACK:
+ ggml_cuda_op_rms_norm_back(ctx, dst);
+ break;
+ case GGML_OP_MUL_MAT:
+ ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
+ break;
+ case GGML_OP_MUL_MAT_ID:
+ ggml_cuda_mul_mat_id(ctx, dst);
+ break;
+ case GGML_OP_OUT_PROD:
+ ggml_cuda_out_prod(ctx, dst);
+ break;
+ case GGML_OP_SCALE:
+ ggml_cuda_op_scale(ctx, dst);
+ break;
+ case GGML_OP_SQR:
+ ggml_cuda_op_sqr(ctx, dst);
+ break;
+ case GGML_OP_SQRT:
+ ggml_cuda_op_sqrt(ctx, dst);
+ break;
+ case GGML_OP_SIN:
+ ggml_cuda_op_sin(ctx, dst);
+ break;
+ case GGML_OP_COS:
+ ggml_cuda_op_cos(ctx, dst);
+ break;
+ case GGML_OP_CLAMP:
+ ggml_cuda_op_clamp(ctx, dst);
+ break;
+ case GGML_OP_LOG:
+ ggml_cuda_op_log(ctx, dst);
+ break;
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ break;
+ case GGML_OP_DIAG:
+ ggml_cuda_op_diag(ctx, dst);
+ break;
+ case GGML_OP_DIAG_MASK_INF:
+ ggml_cuda_op_diag_mask_inf(ctx, dst);
+ break;
+ case GGML_OP_SOFT_MAX:
+ ggml_cuda_op_soft_max(ctx, dst);
+ break;
+ case GGML_OP_SOFT_MAX_BACK:
+ ggml_cuda_op_soft_max_back(ctx, dst);
+ break;
+ case GGML_OP_ROPE:
+ ggml_cuda_op_rope(ctx, dst);
+ break;
+ case GGML_OP_ROPE_BACK:
+ ggml_cuda_op_rope_back(ctx, dst);
+ break;
+ case GGML_OP_ROLL:
+ ggml_cuda_op_roll(ctx, dst);
+ break;
+ case GGML_OP_IM2COL:
+ ggml_cuda_op_im2col(ctx, dst);
+ break;
+ case GGML_OP_IM2COL_3D:
+ ggml_cuda_op_im2col_3d(ctx, dst);
+ break;
+ case GGML_OP_CONV_2D:
+ ggml_cuda_op_conv2d(ctx, dst);
+ break;
+ case GGML_OP_CONV_2D_DW:
+ ggml_cuda_op_conv2d_dw(ctx, dst);
+ break;
+ case GGML_OP_CONV_TRANSPOSE_2D:
+ ggml_cuda_conv_2d_transpose_p0(ctx, dst);
+ break;
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ ggml_cuda_op_conv_transpose_1d(ctx,dst);
+ break;
+ case GGML_OP_POOL_2D:
+ ggml_cuda_op_pool2d(ctx, dst);
+ break;
+ case GGML_OP_SUM:
+ ggml_cuda_op_sum(ctx, dst);
+ break;
+ case GGML_OP_CUMSUM:
+ ggml_cuda_op_cumsum(ctx, dst);
+ break;
+ case GGML_OP_SUM_ROWS:
+ ggml_cuda_op_sum_rows(ctx, dst);
+ break;
+ case GGML_OP_MEAN:
+ ggml_cuda_op_mean(ctx, dst);
+ break;
+ case GGML_OP_SSM_CONV:
+ ggml_cuda_op_ssm_conv(ctx, dst);
+ break;
+ case GGML_OP_SSM_SCAN:
+ ggml_cuda_op_ssm_scan(ctx, dst);
+ break;
+ case GGML_OP_TOP_K:
+ ggml_cuda_op_top_k(ctx, dst);
+ break;
+ case GGML_OP_ARGSORT:
+ ggml_cuda_op_argsort(ctx, dst);
+ break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ ggml_cuda_flash_attn_ext(ctx, dst);
+ break;
+ case GGML_OP_CROSS_ENTROPY_LOSS:
+ ggml_cuda_cross_entropy_loss(ctx, dst);
+ break;
+ case GGML_OP_TRI:
+ ggml_cuda_op_tri(ctx, dst);
+ break;
+ case GGML_OP_RWKV_WKV6:
+ ggml_cuda_op_rwkv_wkv6(ctx, dst);
+ break;
+ case GGML_OP_GATED_LINEAR_ATTN:
+ ggml_cuda_op_gated_linear_attn(ctx, dst);
+ break;
+ case GGML_OP_RWKV_WKV7:
+ ggml_cuda_op_rwkv_wkv7(ctx, dst);
+ break;
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+ ggml_cuda_cross_entropy_loss_back(ctx, dst);
+ break;
+ case GGML_OP_OPT_STEP_ADAMW:
+ ggml_cuda_opt_step_adamw(ctx, dst);
+ break;
+ case GGML_OP_OPT_STEP_SGD:
+ ggml_cuda_opt_step_sgd(ctx, dst);
+ break;
+ case GGML_OP_SOLVE_TRI:
+ ggml_cuda_op_solve_tri(ctx, dst);
+ break;
+ case GGML_OP_FILL:
+ ggml_cuda_op_fill(ctx, dst);
+ break;
+ default:
+ return false;
+ }
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess) {
+ GGML_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst));
+ CUDA_CHECK(err);
+ }
+
+ return true;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+// backend
+
+static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+ return cuda_ctx->name.c_str();
+}
+
+static void ggml_backend_cuda_free(ggml_backend_t backend) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+ delete cuda_ctx;
+ delete backend;
+}
+
+static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+ GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
+
+ CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream()));
+}
+
+static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+ ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+ GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");
+
+ CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream()));
+}
+
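+// Asynchronous tensor copy between two CUDA backends. Same-device copies are
+// issued with cudaMemcpyAsync on the source stream; cross-device copies use
+// cudaMemcpyPeerAsync (unless GGML_CUDA_NO_PEER_COPY is defined). When the
+// backends differ, an event recorded on the source stream makes the
+// destination stream wait for the copy to complete.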
+static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
+ ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
+ ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
+
+ if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
+ return false;
+ }
+
+ if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) {
+ return false;
+ }
+
+ // device -> device copy
+ ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;
+ ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;
+
+ ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
+ ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;
+
+ if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
+#ifndef NDEBUG
+ GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
+#endif
+ return false;
+ }
+
+ if (backend_src != backend_dst) {
+ // copy on src stream
+ if (cuda_ctx_src->device == cuda_ctx_dst->device) {
+ CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
+ } else {
+#ifdef GGML_CUDA_NO_PEER_COPY
+ return false;
+#else
+ CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream()));
+#endif
+ }
+
+ // record event on src stream after the copy
+ if (!cuda_ctx_src->copy_event) {
+ ggml_cuda_set_device(cuda_ctx_src->device);
+ CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
+ }
+
+ CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream()));
+
+ // wait on dst stream for the copy to complete
+ CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), cuda_ctx_src->copy_event, 0));
+ } else {
+ // src and dst are on the same backend
+ CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
+ }
+ return true;
+}
+
+static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+ CUDA_CHECK(cudaStreamSynchronize(cuda_ctx->stream()));
+
+ GGML_UNUSED(backend);
+}
+
+#ifdef USE_CUDA_GRAPH
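+// Scan the ggml graph for nodes that are incompatible with CUDA graph capture:
+// split buffers, batched MUL_MAT_ID, and (with a few named exceptions such as
+// Gemma3n's per-layer projection) additions with batch size > 1, since changes
+// in batch or context size can change the grid size of some kernels.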
+static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
+
+ bool use_cuda_graph = true;
+ // Loop over nodes in GGML graph to obtain info needed for CUDA graph
+
+ const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
+ const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
+ const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
+ const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
+ const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
+ const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
+ const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ ggml_tensor * node = cgraph->nodes[i];
+
+ if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+ continue;
+ }
+
+ if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
+ use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
+#ifndef NDEBUG
+ GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
+#endif
+ }
+
+ if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
+ use_cuda_graph = false; // This node type is not supported by CUDA graph capture
+#ifndef NDEBUG
+ GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
+#endif
+ }
+
+ if (node->op == GGML_OP_ADD &&
+ node->src[1] && node->src[1]->ne[1] > 1 &&
+ (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
+ (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
+ strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
+ strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
+ strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
+ strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
+ strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
+ // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
+ // by means of matching node names. See
+ // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
+            // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773.
+ // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
+ use_cuda_graph = false;
+#ifndef NDEBUG
+ GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+#endif
+ }
+
+ if (!use_cuda_graph) {
+ break;
+ }
+ }
+
+ return use_cuda_graph;
+}
+
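+// Snapshot the properties of a graph node (data pointer, op, type, shape,
+// strides, source data pointers, op params and flags) so that later graph
+// evaluations can be compared against the one captured into a CUDA graph.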
+static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
+ memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
+ props->node_data = node->data;
+ props->node_op = node->op;
+ props->node_type = node->type;
+ props->flags = node->flags;
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ props->ne[i] = node->ne[i];
+ props->nb[i] = node->nb[i];
+ }
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ if (!node->src[i]) {
+ continue;
+ }
+
+ props->src_data[i] = node->src[i]->data;
+ }
+ memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
+}
+
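+// Compare a node against a previously stored snapshot. A mismatch in any
+// property means the captured CUDA graph is stale; data and source pointers
+// of VIEW nodes are exempt from the comparison.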
+static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
+ if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
+ return false;
+ }
+
+ if (node->op != props->node_op) {
+ return false;
+ }
+
+ if (node->type != props->node_type) {
+ return false;
+ }
+
+ for (int i = 0; i < GGML_MAX_DIMS; i++) {
+ if (node->ne[i] != props->ne[i]) {
+ return false;
+ }
+ if (node->nb[i] != props->nb[i]) {
+ return false;
+ }
+ }
+
+ if (node->op != GGML_OP_VIEW) {
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ if (!node->src[i]) {
+ if (props->src_data[i] != nullptr) {
+ return false;
+ }
+ continue;
+ }
+
+ if (node->src[i]->data != props->src_data[i]) {
+ return false;
+ }
+ }
+ }
+
+ if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
+ return false;
+ }
+
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) {
+ return false;
+ }
+
+ return true;
+}
+
+static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
+ return cgraph->nodes[0];
+}
+
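+// Determine whether the captured CUDA graph needs to be updated by comparing
+// every node, plus every source tensor that does not itself appear earlier in
+// the graph (typically graph inputs), against the snapshots taken on the
+// previous evaluation. The snapshots are refreshed as a side effect so the
+// next call compares against the current graph.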
+static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
+ bool res = false;
+
+ const void * graph_key = ggml_cuda_graph_get_key(cgraph);
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
+
+ if (graph->instance == nullptr) {
+ res = true;
+ }
+
+ // Check if the graph size has changed
+ if (graph->props.size() != (size_t)cgraph->n_nodes) {
+ res = true;
+ graph->props.resize(cgraph->n_nodes);
+ }
+
+ // Loop over nodes in GGML graph to determine if CUDA graph update is required
+ // and store properties to allow this comparison for the next token
+ std::unordered_set<ggml_tensor *> seen_node;
+ std::vector<ggml_tensor *> srcs_extra;
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ bool props_match = true;
+
+ seen_node.insert(cgraph->nodes[i]);
+
+ if (!res) {
+ props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
+ }
+ if (!props_match) {
+ res = true;
+ }
+ ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
+
+ for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
+ ggml_tensor * src = cgraph->nodes[i]->src[src_idx];
+ if (src && seen_node.find(src) == seen_node.end()) {
+ srcs_extra.push_back(src);
+ }
+ }
+ }
+
+ if (graph->extra.size() != (size_t) srcs_extra.size()) {
+ res = true;
+ graph->extra.resize(srcs_extra.size());
+ }
+
+ for (size_t i = 0; i < srcs_extra.size(); ++i) {
+ bool props_match = true;
+
+ if (!res) {
+ props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);
+ }
+
+ if (!props_match) {
+ res = true;
+ }
+ ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);
+ }
+
+ return res;
+}
+
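+// Apply the freshly captured graph to the existing executable instance via
+// cudaGraphExecUpdate; if the update violates instantiation constraints, the
+// instance is destroyed and re-instantiated from scratch.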
+static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
+
+#if CUDART_VERSION >= 12000
+ cudaGraphExecUpdateResultInfo result_info;
+ cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info);
+#else
+ cudaGraphNode_t errorNode;
+ cudaGraphExecUpdateResult result_info;
+ cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info);
+#endif // CUDART_VERSION >= 12000
+
+ if (stat == cudaErrorGraphExecUpdateFailure) {
+#ifndef NDEBUG
+ GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
+#endif
+
+ // The pre-existing graph exec cannot be updated due to violated constraints
+ // so instead clear error and re-instantiate
+ (void)cudaGetLastError();
+ CUDA_CHECK(cudaGraphExecDestroy(graph->instance));
+ graph->instance = nullptr;
+ CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
+ } else {
+ GGML_ASSERT(stat == cudaSuccess);
+ }
+}
+#endif // USE_CUDA_GRAPH
+
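+// Check whether a ROPE -> VIEW -> SET_ROWS chain can be fused into a single
+// kernel launch: F32/F16 destination, I64 row indices, the view must flatten
+// two rope dims into one contiguous row, and only the norm/neox rope modes
+// have fused kernels.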
+static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
+ const ggml_tensor * view,
+ const ggml_tensor * set_rows) {
+
+ if (rope->op != GGML_OP_ROPE || view->op != GGML_OP_VIEW || set_rows->op != GGML_OP_SET_ROWS) {
+ return false;
+ }
+ // ne3 not tested
+ if (rope->src[0]->ne[3] != 1) {
+ return false;
+ }
+
+ if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
+ return false;
+ }
+
+ if (set_rows->src[1]->type != GGML_TYPE_I64) {
+ return false;
+ }
+
+ // The view should flatten two dims of rope into one dim
+ if (!ggml_is_contiguous(view) || view->ne[0] != rope->ne[0] * rope->ne[1]) {
+ return false;
+ }
+
+ // Only norm/neox shaders have the fusion code
+ const int mode = ((const int32_t *) rope->op_params)[2];
+ if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
+ return false;
+ }
+
+ return true;
+}
+
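+// Pattern-match the top-k MoE routing subgraph starting at node_idx and fill
+// in ggml_cuda_topk_moe_args: the gating op (softmax, sigmoid or delayed
+// softmax), an optional expert bias add, the argsort/view/get_rows expert
+// selection, and an optional weight normalization (sum_rows/clamp/div) and
+// scale. Returns false if the node sequence does not match the pattern.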
+static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
+ args.sigmoid = false;
+ args.softmax = false;
+ args.delayed_softmax = false;
+ args.prob_bias = false;
+ args.norm = false;
+
+ const int n_nodes = cgraph->n_nodes;
+ ggml_tensor ** nodes = cgraph->nodes;
+
+ if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {
+ args.softmax = true;
+ }
+
+ if (nodes[node_idx]->op == GGML_OP_UNARY) {
+ if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {
+ return false;
+ }
+ args.sigmoid = true;
+ }
+
+ if (nodes[node_idx]->op == GGML_OP_ARGSORT) {
+ args.delayed_softmax = true;
+ }
+
+ node_idx++;
+
+ if (args.sigmoid || args.softmax) {
+ // SOFTMAX -> RESHAPE
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
+ nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+ return false;
+ }
+ ggml_tensor * probs_reshaped = nodes[node_idx];
+ node_idx++;
+
+ if (node_idx >= n_nodes) {
+ return false;
+ }
+
+ // src of bias add is the unreshaped probs (-2 instead of -1)
+ if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {
+ args.prob_bias = true;
+ node_idx++;
+ }
+ // RESHAPE/ADD -> ARGSORT
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {
+ return false;
+ }
+
+ if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+ return false;
+ } else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {
+ return false;
+ }
+
+ node_idx++;
+
+ // ARGSORT-> VIEW
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
+ nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+ return false;
+ }
+ node_idx++;
+
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {
+ return false;
+ }
+
+ // GET_ROWS
+ if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {
+ return false;
+ }
+ node_idx++;
+ } else if (args.delayed_softmax) {
+ if (node_idx - 2 < 0) {
+ return false;
+ }
+ ggml_tensor * probs_reshaped = nodes[node_idx - 2];
+
+ // VIEW->ARGSORT
+ if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
+ nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+ return false;
+ }
+ node_idx++;
+
+ // GET_ROWS
+ if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
+ nodes[node_idx]->src[0] != probs_reshaped) {
+ return false;
+ }
+ node_idx++;
+
+ static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+ for (const ggml_op op : remaining_ops) {
+ if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+ return false;
+ }
+ node_idx++;
+ }
+ }
+
+    // At this point we can check for the optional norm + scale; everything matched so far is valid even without them
+ if (node_idx >= n_nodes) {
+ return true;
+ }
+
+ if (nodes[node_idx]->op == GGML_OP_RESHAPE) {
+ //check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE
+ static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };
+
+ args.norm = true;
+ for (const ggml_op op : norm_ops) {
+ if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
+ node_idx++;
+ } else {
+ args.norm = false;
+ return true;
+ }
+ }
+
+ // DIV <- CLAMP, RESHAPE
+ if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
+ nodes[node_idx]->src[0] != nodes[node_idx - 3]) {
+ args.norm = false;
+ return true;
+ }
+ node_idx++;
+
+ if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
+ args.norm = false;
+ return true;
+ }
+
+ node_idx++;
+ }
+
+ if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
+ args.scale = true;
+ }
+
+ return true;
+}
+
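+// Decide whether the node sequence starting at node_idx can be fused.
+// Handled patterns include (gated) mul_mat + bias + GLU, rope + set_rows,
+// rms_norm + mul (+ add) with F32 contiguous-row operands, and the
+// scale + tanh + scale softcap sequence.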
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
+ int node_idx,
+ std::initializer_list<enum ggml_op> ops,
+ std::initializer_list<enum ggml_unary_op> unary_ops) {
+#ifndef NDEBUG
+ const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
+ GGML_ASSERT(unary_ops.size() == num_unary);
+#endif
+
+ const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
+ const std::initializer_list<enum ggml_op> & list2) {
+ return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
+ };
+
+ std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
+ std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
+
+ std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
+ std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
+
+ if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) &&
+ ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) {
+ const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
+ const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
+ const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2];
+ const ggml_tensor * ffn_up_bias = cgraph->nodes[node_idx + 3];
+ const ggml_tensor * glu = cgraph->nodes[node_idx + 4];
+
+ if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
+ return true;
+ }
+ }
+
+ if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) &&
+ ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
+ const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
+ const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
+ const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
+
+ if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
+ return true;
+ }
+ }
+
+ std::initializer_list<enum ggml_op> rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS };
+
+ if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
+ const ggml_tensor * rope = cgraph->nodes[node_idx];
+ const ggml_tensor * view = cgraph->nodes[node_idx + 1];
+ const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];
+
+ if (ggml_cuda_should_fuse_rope_set_rows(rope, view, set_rows)) {
+ return true;
+ }
+ }
+
+ if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+ return false;
+ }
+
+ if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+ const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+ const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+ const ggml_tensor *add = nullptr;
+
+ if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
+ add = cgraph->nodes[node_idx+2];
+ }
+
+ GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+
+ //rms norm only supports F32
+ if (mul->src[0]->type != GGML_TYPE_F32 ||
+ mul->src[1]->type != GGML_TYPE_F32 ||
+ mul->type != GGML_TYPE_F32) {
+ return false;
+ }
+
+ if (add && (add->src[0]->type != GGML_TYPE_F32 ||
+ add->src[1]->type != GGML_TYPE_F32 ||
+ add->type != GGML_TYPE_F32) ) {
+ return false;
+ }
+
+ //if rms norm is the B operand, then we don't handle broadcast
+ if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
+ return false;
+ }
+
+        //rms_norm kernel assumes contiguous rows
+ if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+ return false;
+ }
+
+ if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {
+ return false;
+ }
+
+ return true;
+ }
+
+ if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
+ && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
+ const ggml_tensor *scale = cgraph->nodes[node_idx];
+ const ggml_tensor *tanh = cgraph->nodes[node_idx+1];
+ const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
+
+ GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(scale->type == GGML_TYPE_F32);
+
+ if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
+ return false;
+ }
+
+ // Check for bias
+ if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
+ return false;
+ }
+
+ return true;
+ }
+
+ return false;
+}
+
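+// Evaluate the ggml graph on the CUDA backend. When CUDA graphs are enabled
+// and an update is required, the kernels launched here are simultaneously
+// captured into a CUDA graph, which is then instantiated (or updated) and
+// launched. Fusion of adjacent nodes and dispatch to the concurrent streams
+// prepared by graph_optimize also happen in this loop.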
+static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
+ bool graph_evaluated_or_captured = false;
+
+    // flag used to determine whether the device is an integrated GPU
+ const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
+
+ ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
+ bool is_concurrent_event_active = false;
+ ggml_cuda_concurrent_event * concurrent_event = nullptr;
+ bool should_launch_concurrent_events = false;
+
+ const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
+ if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
+ concurrent_event = &stream_ctx.concurrent_events[node];
+
+ is_concurrent_event_active = true;
+
+ GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
+
+ cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
+ GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
+ CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
+
+ for (int i = 1; i <= concurrent_event->n_streams; ++i) {
+ cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
+ CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
+ }
+ }
+ };
+
+ while (!graph_evaluated_or_captured) {
+ // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
+ // With the use of CUDA graphs, the execution will be performed by the graph launch.
+ if (!use_cuda_graph || cuda_graph_update_required) {
+ [[maybe_unused]] int prev_i = 0;
+
+ if (stream_ctx.concurrent_events.size() > 0) {
+ should_launch_concurrent_events = true;
+ for (const auto & [tensor, event] : stream_ctx.concurrent_events) {
+ should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();
+ }
+ }
+
+ if (should_launch_concurrent_events) {
+ // Restore original node order within each concurrent region to enable fusion within streams
+
+ std::unordered_map<const ggml_tensor *, int> node_to_idx;
+ node_to_idx.reserve(cgraph->n_nodes);
+ for (int i = 0; i < cgraph->n_nodes; ++i) {
+ node_to_idx[cgraph->nodes[i]] = i;
+ }
+
+ for (auto & [fork_node, event] : stream_ctx.concurrent_events) {
+ // Find positions of all nodes from this event in the current graph
+ std::vector<int> positions;
+ positions.reserve(event.original_order.size());
+
+ bool all_found = true;
+ for (const ggml_tensor * orig_node : event.original_order) {
+ auto it = node_to_idx.find(orig_node);
+ if (it != node_to_idx.end()) {
+ positions.push_back(it->second);
+ } else {
+ all_found = false;
+ break;
+ }
+ }
+
+ if (!all_found || positions.size() != event.original_order.size()) {
+ continue;
+ }
+
+ // Sort positions to get contiguous range
+ std::vector<int> sorted_positions = positions;
+ std::sort(sorted_positions.begin(), sorted_positions.end());
+
+ bool is_contiguous = true;
+ for (size_t i = 1; i < sorted_positions.size(); ++i) {
+ if (sorted_positions[i] != sorted_positions[i-1] + 1) {
+ is_contiguous = false;
+ break;
+ }
+ }
+
+ if (!is_contiguous) {
+ continue;
+ }
+
+ // Restore original order at the sorted positions
+ int start_pos = sorted_positions[0];
+ for (size_t i = 0; i < event.original_order.size(); ++i) {
+ cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);
+ }
+ }
+ } else {
+ stream_ctx.concurrent_events.clear();
+ }
+
+ for (int i = 0; i < cgraph->n_nodes; i++) {
+ ggml_tensor * node = cgraph->nodes[i];
+ if (is_concurrent_event_active) {
+ GGML_ASSERT(concurrent_event);
+
+ if (node == concurrent_event->join_node) {
+ cuda_ctx->curr_stream_no = 0;
+ for (int i = 1; i <= concurrent_event->n_streams; ++i) {
+ // Wait on join events of forked streams in the main stream
+ CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1],
+ cuda_ctx->stream(cuda_ctx->device, i)));
+ CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1]));
+ }
+
+ is_concurrent_event_active = false;
+ concurrent_event = nullptr;
+ } else {
+ GGML_ASSERT (concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end());
+ cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
+ GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
+ }
+ } else if (i - prev_i > 1) {
+ //the previous node was fused
+ const ggml_tensor * prev_node = cgraph->nodes[i - 1];
+ try_launch_concurrent_event(prev_node);
+
+ if (is_concurrent_event_active) {
+ cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
+ GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
+ }
+ }
+
+#ifdef GGML_CUDA_DEBUG
+ const int nodes_fused = i - prev_i - 1;
+ if (nodes_fused > 0) {
+ GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
+ }
+#endif
+ prev_i = i;
+
+ if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+ continue;
+ }
+
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+ continue;
+ }
+
+ // start of fusion operations
+ static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
+ if (!disable_fusion) {
+ ggml_cuda_topk_moe_args args;
+
+ if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
+ cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
+ const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
+
+ std::vector<ggml_op> ops;
+
+ if (can_fuse) {
+ const ggml_tensor * logits = node->src[0];
+ ggml_tensor * weights = nullptr;
+ ggml_tensor * ids = nullptr;
+ const ggml_tensor * bias = nullptr;
+ const ggml_tensor * clamp = nullptr;
+ const ggml_tensor * scale = nullptr;
+
+ if (!args.delayed_softmax) {
+ ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
+ int out_nodes[2]; // nodes which can't be elided
+
+ if (args.prob_bias) {
+ bias = cgraph->nodes[i + 2]->src[1];
+ ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
+ GGML_OP_VIEW, GGML_OP_GET_ROWS });
+ out_nodes[0] = i + 4;
+ ids = cgraph->nodes[i + 4];
+ } else {
+ ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
+ GGML_OP_GET_ROWS });
+ out_nodes[0] = i + 3;
+ ids = cgraph->nodes[i + 3];
+ }
+
+ if (args.norm) {
+ ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
+ GGML_OP_DIV, GGML_OP_RESHAPE });
+ clamp = cgraph->nodes[i + ops.size() - 3];
+ }
+ if (args.scale) {
+ ops.insert(ops.end(), { GGML_OP_SCALE });
+ scale = cgraph->nodes[i + ops.size() - 1];
+ }
+
+ weights = cgraph->nodes[i + ops.size() - 1];
+ out_nodes[1] = i + ops.size() - 1;
+
+ if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
+ ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) {
+ ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
+ i += ops.size() - 1;
+ continue;
+ }
+ } else if (!args.norm && !args.prob_bias) {
+ //special case gpt-oss, no norm, no bias.
+ ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
+ GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
+ weights = cgraph->nodes[i + 5];
+ ids = cgraph->nodes[i + 1];
+ const ggml_tensor * softmax = cgraph->nodes[i + 4];
+
+ int out_nodes[2] = { i + 1, i + 5 };
+ if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
+ ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) {
+ ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
+ i += ops.size() - 1;
+ continue;
+ }
+ }
+ }
+ }
+
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
+ ggml_tensor * rope = cgraph->nodes[i];
+ ggml_tensor * set_rows = cgraph->nodes[i + 2];
+
+ ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
+ i += 2;
+ continue;
+ }
+
+ if (node->op == GGML_OP_ADD) {
+ int n_fuse = 0;
+ ggml_op ops[8];
+ std::fill(ops, ops + 8, GGML_OP_ADD);
+
+ for (; n_fuse <= 6; ++n_fuse){
+ if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
+ break;
+ }
+ if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
+ break;
+ }
+ if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
+ break;
+ }
+ }
+
+ n_fuse++;
+
+ if (n_fuse > 1) {
+ for (int j = 0; j < n_fuse - 1; ++j) {
+ node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
+ }
+ cgraph->nodes[i + n_fuse - 1]->data = node->data;
+ ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
+ i += n_fuse - 1;
+
+ continue;
+ }
+ }
+
+ bool fused_mul_mat_vec = false;
+ int fused_node_count = 0;
+
+ for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
+ const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
+
+ if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
+ ggml_tensor * glu = cgraph->nodes[i + 4];
+ ggml_tensor * gate_bias_n = glu->src[0];
+ ggml_tensor * up_bias_n = glu->src[1];
+
+ //we don't assume the order for {gate, up}. Instead infer it from the bias tensor
+ ggml_tensor * gate_n = nullptr;
+ ggml_tensor * up_n = nullptr;
+
+ if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
+ gate_n = cgraph->nodes[i];
+ up_n = cgraph->nodes[i + 2];
+ } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
+ gate_n = cgraph->nodes[i + 2];
+ up_n = cgraph->nodes[i];
+ } else {
+ continue;
+ }
+
+ auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
+ if (op_bias == GGML_OP_ADD) {
+ if (bias_node->src[0] == mul_node) {
+ return bias_node->src[1];
+ }
+ if (bias_node->src[1] == mul_node) {
+ return bias_node->src[0];
+ }
+ return (ggml_tensor *) nullptr;
+ }
+ GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
+ GGML_ASSERT(bias_node->src[0] == mul_node);
+ return bias_node->src[1];
+ };
+
+ ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op);
+ ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
+
+ if (!up_bias_tensor || !gate_bias_tensor) {
+ continue;
+ }
+
+ // we don't support repeating adds
+ if (bias_op == GGML_OP_ADD &&
+ (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||
+ !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {
+ continue;
+ }
+
+ const ggml_tensor * src0 = up_n->src[0];
+ const ggml_tensor * src1 = up_n->src[1];
+ const ggml_tensor * ids = up_n->src[2];
+
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
+ ggml_cuda_mm_fusion_args_host fusion_data{};
+ fusion_data.gate = gate_n->src[0];
+ fusion_data.x_bias = up_bias_tensor;
+ fusion_data.gate_bias = gate_bias_tensor;
+ fusion_data.glu_op = ggml_get_glu_op(glu);
+
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
+ fused_mul_mat_vec = true;
+ fused_node_count = 5;
+ break;
+ }
+
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
+ ggml_cuda_mm_fusion_args_host fusion_data{};
+ fusion_data.gate = gate_n->src[0];
+ fusion_data.x_bias = up_bias_tensor;
+ fusion_data.gate_bias = gate_bias_tensor;
+ fusion_data.glu_op = ggml_get_glu_op(glu);
+
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
+ fused_mul_mat_vec = true;
+ fused_node_count = 5;
+ break;
+ }
+ } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
+ ggml_tensor * glu = cgraph->nodes[i + 2];
+ ggml_tensor * gate = glu->src[0];
+ ggml_tensor * up = glu->src[1];
+
+ bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])
+ || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
+
+ if (!ok) continue;
+
+ const ggml_tensor * src0 = up->src[0];
+ const ggml_tensor * src1 = up->src[1];
+ const ggml_tensor * ids = up->src[2];
+
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
+ ggml_cuda_mm_fusion_args_host fusion_data{};
+ fusion_data.gate = gate->src[0];
+ fusion_data.glu_op = ggml_get_glu_op(glu);
+
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
+ fused_mul_mat_vec = true;
+ fused_node_count = 3;
+ break;
+ }
+
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
+ ggml_cuda_mm_fusion_args_host fusion_data{};
+ fusion_data.gate = gate->src[0];
+ fusion_data.glu_op = ggml_get_glu_op(glu);
+
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
+ fused_mul_mat_vec = true;
+ fused_node_count = 3;
+ break;
+ }
+ }
+ }
+
+ if (fused_mul_mat_vec) {
+ i += fused_node_count - 1;
+ continue;
+ }
+
+ fused_mul_mat_vec = false;
+ fused_node_count = 0;
+
+ for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
+ const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
+
+ if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
+ continue;
+ }
+
+ ggml_tensor * mm_node = cgraph->nodes[i];
+ ggml_tensor * bias_node = cgraph->nodes[i + 1];
+
+ ggml_tensor * bias_tensor = nullptr;
+ if (bias_op == GGML_OP_ADD) {
+ if (bias_node->src[0] == mm_node) {
+ bias_tensor = bias_node->src[1];
+ } else if (bias_node->src[1] == mm_node) {
+ bias_tensor = bias_node->src[0];
+ } else {
+ continue;
+ }
+ } else {
+ if (bias_node->src[0] != mm_node) {
+ continue;
+ }
+ bias_tensor = bias_node->src[1];
+ }
+
+ const ggml_tensor * src0 = mm_node->src[0];
+ const ggml_tensor * src1 = mm_node->src[1];
+ const ggml_tensor * ids = mm_node->src[2];
+
+ if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
+ continue;
+ }
+
+ if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {
+ continue;
+ }
+
+ ggml_cuda_mm_fusion_args_host fusion_data{};
+ fusion_data.x_bias = bias_tensor;
+
+ if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
+ ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
+ fused_mul_mat_vec = true;
+ fused_node_count = 2;
+ break;
+ }
+
+ if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
+ ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
+ fused_mul_mat_vec = true;
+ fused_node_count = 2;
+ break;
+ }
+ }
+
+ if (fused_mul_mat_vec) {
+ i += fused_node_count - 1;
+ continue;
+ }
+
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
+ ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+ i += 2;
+ continue;
+ }
+
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
+ ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+ i++;
+ continue;
+ }
+
+ if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
+ i += 2;
+ ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
+ continue;
+ }
+ }
+#ifndef NDEBUG
+ assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
+ for (int j = 0; j < GGML_MAX_SRC; j++) {
+ if (node->src[j] != nullptr) {
+ assert(node->src[j]->buffer);
+ assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
+ ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
+ }
+ }
+#else
+ GGML_UNUSED(integrated);
+#endif // NDEBUG
+
+ bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
+ if (!ok) {
+ GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
+ }
+ GGML_ASSERT(ok);
+
+ if (!is_concurrent_event_active) {
+ try_launch_concurrent_event(node);
+ }
+ }
+ }
+
+#ifdef USE_CUDA_GRAPH
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
+ if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
+ if (graph->graph != nullptr) {
+ CUDA_CHECK(cudaGraphDestroy(graph->graph));
+ graph->graph = nullptr;
+ }
+
+ CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph));
+ graph_evaluated_or_captured = true; // CUDA graph has been captured
+
+ std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+ if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
+ ggml_cuda_lock_cv.notify_all();
+ }
+ } else {
+ graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
+ }
+ }
+
+ if (use_cuda_graph) {
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
+ if (graph->instance == nullptr) { // Create executable graph from captured graph.
+ CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
+ }
+ if (cuda_graph_update_required) { // Update graph executable
+ ggml_cuda_graph_update_executable(cuda_ctx, graph_key);
+ }
+ // Launch graph
+ CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream()));
+#else
+ GGML_UNUSED(graph_key);
+ graph_evaluated_or_captured = true;
+#endif // USE_CUDA_GRAPH
+ }
+}
+
+#ifdef USE_CUDA_GRAPH
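+// Returns whether CUDA graphs should be used for this ggml graph. Graphs are
+// disabled on pre-Ampere devices; any other disable conditions are tracked by
+// ggml_cuda_graph::is_enabled().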
+static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
+
+ if (graph->graph == nullptr) {
+ if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
+ if (!graph->disable_due_to_gpu_arch) {
+ GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+ }
+ graph->disable_due_to_gpu_arch = true;
+ }
+ }
+
+ return graph->is_enabled();
+}
+#endif // USE_CUDA_GRAPH
+
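+// Backend entry point for graph execution: decides whether a CUDA graph can be
+// used and whether it needs re-capturing, begins stream capture if so, and then
+// delegates to ggml_cuda_graph_evaluate_and_capture.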
+static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
+
+ ggml_cuda_set_device(cuda_ctx->device);
+
+ bool use_cuda_graph = false;
+ bool cuda_graph_update_required = false;
+ const void * graph_key = nullptr;
+
+#ifdef USE_CUDA_GRAPH
+ graph_key = ggml_cuda_graph_get_key(cgraph);
+
+ use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
+
+ ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
+ if (graph->is_enabled()) {
+ cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
+ use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);
+
+ graph->record_update(use_cuda_graph, cuda_graph_update_required);
+ }
+#endif // USE_CUDA_GRAPH
+
+ if (use_cuda_graph && cuda_graph_update_required) {
+ // Start CUDA graph capture
+ {
+ std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+ ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
+ }
+
+ CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
+ }
+
+ ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key);
+
+ return GGML_STATUS_SUCCESS;
+}
+
+static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+ CUDA_CHECK(cudaEventRecord((cudaEvent_t)event->context, cuda_ctx->stream()));
+}
+
+static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+
+ if (ggml_backend_is_cuda(backend)) {
+ CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), (cudaEvent_t)event->context, 0));
+ } else {
+#if 0
+ // untested
+ auto wait_fn = [](void * user_data) {
+ ggml_backend_event_t event = (ggml_backend_event_t)user_data;
+ ggml_backend_event_synchronize(event);
+ };
+
+ CUDA_CHECK(cudaLaunchHostFunc(cuda_ctx->stream(), wait_fn, event));
+#endif
+ GGML_ABORT("fatal error");
+ }
+}
+
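+// Optional graph optimization pass (opt-in via GGML_CUDA_GRAPH_OPT=1, and only
+// run with CUDA graphs enabled on a single device): detect fork/join regions
+// such as the Q/K/V projections off "attn_norm", assign each branch to its own
+// stream, and interleave the branch nodes so that ggml does not recycle their
+// buffers while the branches run concurrently.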
+static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
+ ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
+
+#ifdef USE_CUDA_GRAPH
+ const void * graph_key = ggml_cuda_graph_get_key(cgraph);
+ const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
+#else
+ const bool use_cuda_graph = false;
+ GGML_UNUSED(cuda_ctx);
+ GGML_UNUSED(cgraph);
+#endif
+
+ static bool enable_graph_optimization = [] {
+ const char * env = getenv("GGML_CUDA_GRAPH_OPT");
+ return env != nullptr && atoi(env) == 1;
+ }();
+
+ if (!enable_graph_optimization) {
+ return;
+ }
+
+ ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
+ stream_context.reset();
+
+ if (!use_cuda_graph || ggml_backend_cuda_get_device_count() != 1) {
+ return;
+ }
+
+    // out-degree (number of consumers) of each node
+ std::unordered_map<const ggml_tensor *, int> fan_out;
+ // reverse mapping of node to index in the cgraph
+ std::unordered_map<const ggml_tensor *, int> node_indices;
+
+ const auto & is_noop = [](const ggml_tensor * node) -> bool {
+ return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE ||
+ node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
+ };
+
+ const auto & depends_on = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool {
+ for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
+ if (dst->src[s] == src) {
+ return true;
+ }
+ }
+ // implicit dependency if they view the same tensor
+ const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst;
+ const ggml_tensor * src2 = src->view_src ? src->view_src : src;
+ if (dst2 == src2) {
+ return true;
+ }
+ return false;
+ };
+
+ for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
+ const ggml_tensor * node = cgraph->nodes[node_idx];
+ node_indices[node] = node_idx;
+
+ if (is_noop(node)) {
+ continue;
+ }
+ for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
+ const ggml_tensor * src = cgraph->nodes[node_idx]->src[src_idx];
+ //TODO: check why nrows > 1 fails
+ if (node && !is_noop(node) && ggml_nrows(node) <= 1) {
+ fan_out[src] += 1;
+ }
+ }
+ }
+
+ // Target Q, K, V for concurrency
+ // this is a more general way to find nodes which can be candidates for concurrency (although it has not been tested for anything else):
+ // 1. find fan-out (fork) nodes where the same input is used at least N times (in QKV, it would be "attn-norm")
+    // 2. find the join node, where 2 or more of the outputs are required (in QKV, this would be "KQ" or "flash-attn")
+ // 3. account for all branches from the fork to the join
+ // 4. To extend lifetimes of the tensors, we interleave the branches (see below for more details)
+ // 5. save the original cgraph and restore it in graph_compute, to enable fusion within streams
+ // See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030
+
+ const int min_fan_out = 3;
+ const int max_fan_out = 3;
+
+ // store {fork_idx, join_idx}
+ std::vector<std::pair<int, int>> concurrent_node_ranges;
+
+ for (const auto & [root_node, count] : fan_out) {
+ if (count >= min_fan_out && count <= max_fan_out) {
+ const int root_node_idx = node_indices[root_node];
+
+ // only optimize for attn_norm
+ // TODO: make this more generic
+ if (!strstr(root_node->name, "attn_norm")) {
+ continue;
+ }
+
+ bool is_part_of_event = false;
+ for (const auto & [start, end] : concurrent_node_ranges) {
+ if (root_node_idx >= start && root_node_idx <= end) {
+ is_part_of_event = true;
+ }
+ }
+
+ if (is_part_of_event) {
+ continue;
+ }
+
+ std::vector<std::vector<const ggml_tensor *>> nodes_per_branch;
+ for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
+ const ggml_tensor * node = cgraph->nodes[i];
+ if (!is_noop(node) && depends_on(node, root_node)) {
+ nodes_per_branch.push_back({ node });
+ }
+ }
+
+ GGML_ASSERT(nodes_per_branch.size() == (size_t) count);
+
+ //find the join point
+ const ggml_tensor * join_node = nullptr;
+
+ const auto & belongs_to_branch = [&](const ggml_tensor * node,
+ const std::vector<const ggml_tensor *> & branch) -> bool {
+ for (const ggml_tensor * n : branch) {
+ if (depends_on(node, n)) {
+ return true;
+ }
+ }
+ return false;
+ };
+
+ for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
+ const ggml_tensor * curr_node = cgraph->nodes[i];
+
+ int num_joins = 0;
+ for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
+ if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) {
+ num_joins++;
+ }
+ }
+
+ if (num_joins >= 2) {
+ join_node = curr_node;
+ break;
+ }
+
+ bool found_branch = false;
+ for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
+ std::vector<const ggml_tensor *> & branch_vec = nodes_per_branch[branch_idx];
+ if (belongs_to_branch(curr_node, branch_vec)) {
+ //continue accumulating
+ if (std::find(branch_vec.begin(), branch_vec.end(), curr_node) == branch_vec.end()) {
+ branch_vec.push_back(curr_node);
+ }
+ found_branch = true;
+ }
+ }
+
+ if (!found_branch && is_noop(curr_node)) {
+ // we can put it in any branch because it will be ignored
+ nodes_per_branch[0].push_back({ curr_node });
+ }
+ }
+
+ if (join_node) {
+ //Create ggml_cuda_concurrent_event
+ ggml_cuda_concurrent_event concurrent_event(nodes_per_branch.size());
+ concurrent_event.join_node = join_node;
+
+ for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
+ for (const ggml_tensor * n : nodes_per_branch[branch_idx]) {
+ concurrent_event.stream_mapping[n] = branch_idx + 1;
+ }
+ }
+
+ int fork_node_idx = node_indices[root_node];
+ int join_node_idx = node_indices[join_node];
+
+ int current_branch_idx = 0;
+ int current_node_idx = fork_node_idx + 1;
+ const int n_branches = nodes_per_branch.size();
+
+ int total_branch_nodes = 0;
+ for (std::vector<const ggml_tensor *> branch_nodes : nodes_per_branch) {
+ total_branch_nodes += branch_nodes.size();
+ }
+
+                // if there are other nodes in the middle which are unaccounted for
+                // (usually cpy nodes), ignore this fork
+ if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) {
+ GGML_LOG_DEBUG(
+ "Skipping %s because the number of nodes in the middle is not equal to the total number of "
+ "branch nodes %d != %d\n",
+ root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes);
+ continue;
+ }
+
+ // Save the original order of nodes in this region before interleaving
+ // This is used later to restore grouping for fusion within streams
+ concurrent_event.original_order.reserve(total_branch_nodes);
+ for (int i = fork_node_idx + 1; i < join_node_idx; ++i) {
+ concurrent_event.original_order.push_back(cgraph->nodes[i]);
+ }
+
+ std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
+ GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
+ concurrent_events.emplace(root_node, std::move(concurrent_event));
+ GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node);
+ concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx);
+
+ // interleave tensors to extend lifetimes so that ggml graph doesn't recycle them
+ // example transformation:
+ // [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] ->
+ // [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn]
+ while (current_node_idx < join_node_idx) {
+ std::vector<const ggml_tensor *> & branch_nodes = nodes_per_branch[current_branch_idx];
+
+ bool has_node = false;
+ for (std::vector<const ggml_tensor *> branch_node : nodes_per_branch) {
+ has_node |= branch_node.size() > 0;
+ }
+
+ GGML_ASSERT(has_node);
+
+ if (branch_nodes.empty()) {
+ current_branch_idx = (current_branch_idx + 1) % n_branches;
+ continue;
+ }
+
+ cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
+ current_node_idx++;
+ branch_nodes.erase(branch_nodes.begin());
+
+ // append all empty nodes
+ while (!branch_nodes.empty() && is_noop(branch_nodes.front())) {
+ cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
+ current_node_idx++;
+ branch_nodes.erase(branch_nodes.begin());
+ }
+
+ current_branch_idx = (current_branch_idx + 1) % n_branches;
+ }
+ }
+ }
+ }
+}
+
+static const ggml_backend_i ggml_backend_cuda_interface = {
+ /* .get_name = */ ggml_backend_cuda_get_name,
+ /* .free = */ ggml_backend_cuda_free,
+ /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async,
+ /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async,
+ /* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async,
+ /* .synchronize = */ ggml_backend_cuda_synchronize,
+ /* .graph_plan_create = */ NULL,
+ /* .graph_plan_free = */ NULL,
+ /* .graph_plan_update = */ NULL,
+ /* .graph_plan_compute = */ NULL,
+ /* .graph_compute = */ ggml_backend_cuda_graph_compute,
+ /* .event_record = */ ggml_backend_cuda_event_record,
+ /* .event_wait = */ ggml_backend_cuda_event_wait,
+ /* .graph_optimize = */ ggml_backend_cuda_graph_optimize,
+};
+
+static ggml_guid_t ggml_backend_cuda_guid() {
+ static ggml_guid guid = { 0x2c, 0xdd, 0xe8, 0x1c, 0x65, 0xb3, 0x65, 0x73, 0x6a, 0x12, 0x88, 0x61, 0x1c, 0xc9, 0xdc, 0x25 };
+ return &guid;
+}
+
+bool ggml_backend_is_cuda(ggml_backend_t backend) {
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cuda_guid());
+}
+
+int ggml_backend_cuda_get_device_count() {
+ return ggml_cuda_info().device_count;
+}
+
+void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size) {
+ cudaDeviceProp prop;
+ CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
+ snprintf(description, description_size, "%s", prop.name);
+}
+
+void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total) {
+ ggml_cuda_set_device(device);
+
+ CUDA_CHECK(cudaMemGetInfo(free, total));
+}
+
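+// Pin a user buffer with cudaHostRegister so that host<->device copies from it
+// can be faster and asynchronous. This is opt-in via the GGML_CUDA_REGISTER_HOST
+// environment variable and requires CUDA 11.1+ (or MUSA/HIP).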
+bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
+ if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
+ return false;
+ }
+
+#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) || defined(GGML_USE_HIP)
+ cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
+ if (err != cudaSuccess) {
+ // clear the error
+ (void)cudaGetLastError();
+
+ GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
+ size / 1024.0 / 1024.0, cudaGetErrorString(err));
+ return false;
+ }
+ return true;
+#else
+ GGML_UNUSED(buffer);
+ GGML_UNUSED(size);
+ return false;
+#endif // CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) || defined(GGML_USE_HIP)
+}
+
+void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
+ if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
+ return;
+ }
+
+ cudaError_t err = cudaHostUnregister(buffer);
+ if (err != cudaSuccess) {
+ // clear the error
+ (void)cudaGetLastError();
+ }
+}
+
+
+// backend device
+
+struct ggml_backend_cuda_device_context {
+ int device;
+ std::string name;
+ std::string description;
+ std::string pci_bus_id;
+ int op_offload_min_batch_size;
+};
+
+static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+ return ctx->name.c_str();
+}
+
+static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t dev) {
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+ return ctx->description.c_str();
+}
+
+#if defined(__linux__)
+// Helper function to get available memory from /proc/meminfo for UMA systems
+static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_kb, long * free_swap_kb) {
+ FILE * meminfo_file = nullptr;
+    // 2 KiB buffer for reading /proc/meminfo; the file does not report its size, but this should be enough
+ const size_t BUFFER_SIZE = 2048;
+ auto file_buffer = std::make_unique<char[]>(BUFFER_SIZE);
+ size_t bytes_read = 0;
+ long huge_tlb_total_pages = -1;
+ long huge_tlb_free_pages = -1;
+ long huge_tlb_page_size = -1;
+
+ if (available_memory_kb == nullptr || free_swap_kb == nullptr) {
+ return false;
+ }
+
+ meminfo_file = fopen("/proc/meminfo", "r");
+ if (meminfo_file == nullptr) {
+ GGML_LOG_ERROR("%s: failed to open /proc/meminfo\n", __func__);
+ return false;
+ }
+
+ // Read file into buffer
+ bytes_read = fread(file_buffer.get(), 1, BUFFER_SIZE - 1, meminfo_file);
+ fclose(meminfo_file);
+
+ if (bytes_read == 0) {
+ GGML_LOG_ERROR("%s: failed to read from /proc/meminfo\n", __func__);
+ return false;
+ }
+ file_buffer[bytes_read] = '\0';
+
+ *available_memory_kb = -1;
+ *free_swap_kb = -1;
+
+ // Parse the file buffer line by line
+ char * line = file_buffer.get();
+ char * line_next;
+ while (line < file_buffer.get() + bytes_read) {
+ // Find the end of the current line
+ line_next = strchr(line, '\n');
+ if (line_next != nullptr) {
+ *line_next = '\0';
+ line_next++;
+ } else {
+ line_next = file_buffer.get() + bytes_read;
+ }
+
+ long value;
+ if (sscanf(line, "MemAvailable: %ld kB", &value) == 1) {
+ *available_memory_kb = value;
+ } else if (sscanf(line, "SwapFree: %ld kB", &value) == 1) {
+ *free_swap_kb = value;
+ } else if (sscanf(line, "HugePages_Total: %ld", &value) == 1) {
+ huge_tlb_total_pages = value;
+ } else if (sscanf(line, "HugePages_Free: %ld", &value) == 1) {
+ huge_tlb_free_pages = value;
+ } else if (sscanf(line, "Hugepagesize: %ld kB", &value) == 1) {
+ huge_tlb_page_size = value;
+ }
+
+ line = line_next;
+ }
+
+ if (huge_tlb_total_pages != 0 && huge_tlb_total_pages != -1) {
+ *available_memory_kb = huge_tlb_free_pages * huge_tlb_page_size;
+
+ // Hugetlbfs pages are not swappable.
+ *free_swap_kb = 0;
+ }
+
+ GGML_LOG_DEBUG("%s: final available_memory_kb: %ld\n", __func__, *available_memory_kb);
+ return true;
+}
+#endif // defined(__linux__)
+
+static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+ ggml_cuda_set_device(ctx->device);
+ CUDA_CHECK(cudaMemGetInfo(free, total));
+
+// ref: https://github.com/ggml-org/llama.cpp/pull/17368
+#if defined(__linux__)
+ // Check if this is a UMA (Unified Memory Architecture) system
+ cudaDeviceProp prop;
+ CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device));
+
+ // Check if UMA is explicitly enabled via environment variable
+ bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr;
+ bool is_uma = prop.integrated > 0 || uma_env;
+
+ if (is_uma) {
+ // For UMA systems (like DGX Spark), use system memory info
+ long available_memory_kb = 0;
+ long free_swap_kb = 0;
+
+ if (ggml_backend_cuda_get_available_uma_memory(&available_memory_kb, &free_swap_kb) && available_memory_kb > 0) {
+ *free = (size_t)available_memory_kb * 1024;
+ } else {
+ GGML_LOG_ERROR("%s: /proc/meminfo reading failed, using cudaMemGetInfo\n", __func__);
+ }
+ }
+#endif // defined(__linux__)
+
+}
+
+static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
+ GGML_UNUSED(dev);
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
+}
+
+static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+
+ props->name = ggml_backend_cuda_device_get_name(dev);
+ props->description = ggml_backend_cuda_device_get_description(dev);
+ props->type = ggml_backend_cuda_device_get_type(dev);
+ props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
+ ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
+
+ bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
+#ifdef GGML_CUDA_NO_PEER_COPY
+ bool events = false;
+#else
+ bool events = true;
+#endif
+
+ props->caps = {
+ /* .async = */ true,
+ /* .host_buffer = */ host_buffer,
+ /* .buffer_from_host_ptr = */ false,
+ /* .events = */ events,
+ };
+}
+
+static ggml_backend_t ggml_backend_cuda_device_init_backend(ggml_backend_dev_t dev, const char * params) {
+ GGML_UNUSED(params);
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+ return ggml_backend_cuda_init(ctx->device);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_buffer_type(ggml_backend_dev_t dev) {
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+ return ggml_backend_cuda_buffer_type(ctx->device);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_host_buffer_type(ggml_backend_dev_t dev) {
+ GGML_UNUSED(dev);
+ return ggml_backend_cuda_host_buffer_type();
+}
+
+// TODO: move these functions here
+static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
+
+ // split buffers can only be used with GGML_OP_MUL_MAT
+ if (op->op != GGML_OP_MUL_MAT) {
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda_split(op->src[i]->buffer->buft)) {
+ return false;
+ }
+ }
+ }
+
+ // check if all the sources are allocated on this device
+ for (int i = 0; i < GGML_MAX_SRC; i++) {
+ if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda(op->src[i]->buffer->buft)) {
+ ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)op->src[i]->buffer->buft->context;
+ if (buft_ctx->device != dev_ctx->device) {
+ return false;
+ }
+ }
+ }
+
+ switch (op->op) {
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(op)) {
+ case GGML_UNARY_OP_ABS:
+ case GGML_UNARY_OP_SGN:
+ case GGML_UNARY_OP_NEG:
+ case GGML_UNARY_OP_STEP:
+ case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_SIGMOID:
+ case GGML_UNARY_OP_HARDSIGMOID:
+ case GGML_UNARY_OP_HARDSWISH:
+ case GGML_UNARY_OP_GELU_ERF:
+ case GGML_UNARY_OP_GELU_QUICK:
+ case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_EXP:
+ case GGML_UNARY_OP_EXPM1:
+ case GGML_UNARY_OP_SOFTPLUS:
+ case GGML_UNARY_OP_ELU:
+ case GGML_UNARY_OP_XIELU:
+ case GGML_UNARY_OP_FLOOR:
+ case GGML_UNARY_OP_CEIL:
+ case GGML_UNARY_OP_ROUND:
+ case GGML_UNARY_OP_TRUNC:
+ return ggml_is_contiguous(op->src[0]);
+ default:
+ return false;
+ }
+ break;
+ case GGML_OP_GLU:
+ switch (ggml_get_glu_op(op)) {
+ case GGML_GLU_OP_REGLU:
+ case GGML_GLU_OP_GEGLU:
+ case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
+ case GGML_GLU_OP_GEGLU_ERF:
+ case GGML_GLU_OP_GEGLU_QUICK:
+ return ggml_is_contiguous_1(op->src[0]);
+ default:
+ return false;
+ }
+ break;
+ case GGML_OP_MUL_MAT:
+ case GGML_OP_MUL_MAT_ID:
+ {
+ struct ggml_tensor * a = op->src[0];
+ struct ggml_tensor * b = op->src[1];
+ if (a->buffer && ggml_backend_buft_is_cuda_split(a->buffer->buft)) {
+ if (a->ne[2] > 1 || a->ne[3] > 1) {
+ return false;
+ }
+                    // For small weight matrices the active device can end up without any rows; don't use row split in those cases.
+                    // This avoids some edge cases (and the performance would not be good anyway).
+ ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) a->buffer->buft->context;
+ int64_t row_low;
+ int64_t row_high;
+ get_row_split(&row_low, &row_high, a, buft_ctx->tensor_split, dev_ctx->device);
+ if (row_low == row_high) {
+ return false;
+ }
+ }
+ if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
+ return false;
+ }
+#ifdef GGML_USE_MUSA
+ const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+ if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+ if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT &&
+ a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {
+ return false;
+ }
+ if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID &&
+ a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) {
+ return false;
+ }
+ }
+#endif // GGML_USE_MUSA
+ switch (a->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_Q8_K:
+ case GGML_TYPE_IQ1_M:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_BF16:
+ return true;
+ default:
+ return false;
+ }
+ } break;
+ case GGML_OP_OUT_PROD:
+ return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
+ case GGML_OP_GET_ROWS:
+ {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ case GGML_TYPE_BF16:
+ case GGML_TYPE_I32:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ return true;
+ default:
+ return false;
+ }
+ } break;
+ case GGML_OP_GET_ROWS_BACK:
+ {
+ return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
+ } break;
+ case GGML_OP_SET_ROWS:
+ {
+ return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
+ op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
+ op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
+ op->src[0]->type == GGML_TYPE_F32 &&
+ (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
+ } break;
+ case GGML_OP_SET:
+ {
+ const ggml_type t = op->type;
+ return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&
+ t == op->src[0]->type &&
+ t == op->src[1]->type;
+ } break;
+ case GGML_OP_CPY:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ ggml_type src1_type = op->src[1]->type;
+ if ((src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_F16) &&
+ (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F16)
+ ) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_I32) {
+ return true;
+ }
+ if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
+ return true;
+ }
+ return false;
+ } break;
+ case GGML_OP_DUP:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+ } break;
+ case GGML_OP_ARGMAX:
+ case GGML_OP_COUNT_EQUAL:
+ {
+ return true;
+ } break;
+ case GGML_OP_REPEAT:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+ } break;
+ case GGML_OP_REPEAT_BACK:
+ return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15);
+ case GGML_OP_CONCAT:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
+ } break;
+ case GGML_OP_CONV_TRANSPOSE_1D:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ ggml_type src1_type = op->src[1]->type;
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ return false;
+ } break;
+ case GGML_OP_SILU_BACK:
+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+ break;
+ case GGML_OP_NORM:
+ case GGML_OP_RMS_NORM:
+ case GGML_OP_L2_NORM:
+ return true;
+ case GGML_OP_RMS_NORM_BACK:
+ return ggml_is_contiguous(op->src[0]);
+ break;
+ case GGML_OP_NONE:
+ case GGML_OP_RESHAPE:
+ case GGML_OP_VIEW:
+ case GGML_OP_PERMUTE:
+ case GGML_OP_TRANSPOSE:
+ case GGML_OP_ADD:
+ case GGML_OP_ADD_ID:
+ case GGML_OP_ADD1:
+ case GGML_OP_SUB:
+ case GGML_OP_MUL:
+ case GGML_OP_DIV:
+ case GGML_OP_SCALE:
+ case GGML_OP_SQR:
+ case GGML_OP_SQRT:
+ case GGML_OP_SIN:
+ case GGML_OP_COS:
+ case GGML_OP_CLAMP:
+ case GGML_OP_LOG:
+ return true;
+ case GGML_OP_SSM_SCAN: {
+ if (op->src[3]->ne[0] == 1) {
+ // Mamba2
+ // (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
+ return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
+ } else {
+ // Mamba
+ // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
+ return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
+ }
+ }
+ case GGML_OP_SSM_CONV: {
+ // assumes d_inner % threads == 0
+ return op->src[0]->ne[1] % 128 == 0;
+ }
+ case GGML_OP_CONT:
+ return true;
+ case GGML_OP_DIAG_MASK_INF:
+ return true;
+ case GGML_OP_SOFT_MAX:
+ return true;
+ case GGML_OP_SOFT_MAX_BACK: {
+ float max_bias = 0.0f;
+ memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
+ return max_bias == 0.0f;
+ }
+ case GGML_OP_ROLL:
+ if(op->src[0]->type == GGML_TYPE_F32) {
+ return true;
+ }
+ return false;
+ case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK: {
+ return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
+ }
+ case GGML_OP_IM2COL:
+ case GGML_OP_IM2COL_3D:
+ case GGML_OP_CONV_2D:
+ case GGML_OP_CONV_2D_DW:
+ case GGML_OP_CONV_TRANSPOSE_2D:
+ case GGML_OP_POOL_2D:
+ case GGML_OP_ACC:
+ return true;
+ case GGML_OP_SUM:
+ return ggml_is_contiguous_rows(op->src[0]);
+ case GGML_OP_TOP_K:
+ case GGML_OP_ARGSORT:
+#ifndef GGML_CUDA_USE_CUB
+ return op->src[0]->ne[0] <= 1024;
+#else
+ return true;
+#endif
+ case GGML_OP_SUM_ROWS:
+ case GGML_OP_MEAN:
+ case GGML_OP_GROUP_NORM:
+ return ggml_is_contiguous(op->src[0]);
+ case GGML_OP_PAD:
+ return true;
+ case GGML_OP_UPSCALE:
+ case GGML_OP_PAD_REFLECT_1D:
+ case GGML_OP_ARANGE:
+ case GGML_OP_TIMESTEP_EMBEDDING:
+ case GGML_OP_LEAKY_RELU:
+ case GGML_OP_RWKV_WKV6:
+ case GGML_OP_GATED_LINEAR_ATTN:
+ case GGML_OP_RWKV_WKV7:
+ return true;
+ case GGML_OP_FLASH_ATTN_EXT:
+ return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
+ case GGML_OP_CROSS_ENTROPY_LOSS:
+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
+ case GGML_OP_OPT_STEP_ADAMW:
+ case GGML_OP_OPT_STEP_SGD:
+ case GGML_OP_FILL:
+ case GGML_OP_CUMSUM:
+ case GGML_OP_TRI:
+ case GGML_OP_DIAG:
+ case GGML_OP_SOLVE_TRI:
+ return true;
+
+ default:
+ return false;
+ }
+}
+
+static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
+ const bool integrated = ggml_cuda_info().devices[dev_ctx->device].integrated;
+ return (((ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev) || (integrated && ggml_backend_buft_is_cuda_host(buft)));
+}
+
+static int64_t get_op_batch_size(const ggml_tensor * op) {
+ switch (op->op) {
+ case GGML_OP_GET_ROWS:
+ return 0;
+ case GGML_OP_MUL_MAT:
+ return op->ne[1];
+ case GGML_OP_MUL_MAT_ID:
+ case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
+ return op->ne[2];
+ default:
+ return ggml_nrows(op);
+ }
+}
+
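+// Ops are offloaded to the GPU only when their batch dimension is large enough to amortize
+// the host<->device transfer: with the default threshold of 32 (GGML_OP_OFFLOAD_MIN_BATCH),
+// e.g. a GGML_OP_MUL_MAT with ne[1] == 8 stays on the CPU while one with ne[1] == 64 is offloaded.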
+static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
+
+ return get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size;
+}
+
+static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) {
+#ifdef GGML_CUDA_NO_PEER_COPY
+ return nullptr;
+#else
+ ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *)dev->context;
+
+ ggml_cuda_set_device(dev_ctx->device);
+
+ cudaEvent_t event;
+ CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));
+
+ return new ggml_backend_event {
+ /* .device = */ dev,
+ /* .context = */ event,
+ };
+#endif
+}
+
+static void ggml_backend_cuda_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+ GGML_UNUSED(dev);
+
+ CUDA_CHECK(cudaEventDestroy((cudaEvent_t)event->context));
+ delete event;
+}
+
+static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+ GGML_UNUSED(dev);
+ CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
+}
+
+static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
+ /* .get_name = */ ggml_backend_cuda_device_get_name,
+ /* .get_description = */ ggml_backend_cuda_device_get_description,
+ /* .get_memory = */ ggml_backend_cuda_device_get_memory,
+ /* .get_type = */ ggml_backend_cuda_device_get_type,
+ /* .get_props = */ ggml_backend_cuda_device_get_props,
+ /* .init_backend = */ ggml_backend_cuda_device_init_backend,
+ /* .get_buffer_type = */ ggml_backend_cuda_device_get_buffer_type,
+ /* .get_host_buffer_type = */ ggml_backend_cuda_device_get_host_buffer_type,
+ /* .buffer_from_host_ptr = */ NULL,
+ /* .supports_op = */ ggml_backend_cuda_device_supports_op,
+ /* .supports_buft = */ ggml_backend_cuda_device_supports_buft,
+ /* .offload_op = */ ggml_backend_cuda_device_offload_op,
+ /* .event_new = */ ggml_backend_cuda_device_event_new,
+ /* .event_free = */ ggml_backend_cuda_device_event_free,
+ /* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
+};
+
+// backend reg
+
+struct ggml_backend_cuda_reg_context {
+ std::vector<ggml_backend_dev_t> devices;
+};
+
+static const char * ggml_backend_cuda_reg_get_name(ggml_backend_reg_t reg) {
+ GGML_UNUSED(reg);
+ return GGML_CUDA_NAME;
+}
+
+static size_t ggml_backend_cuda_reg_get_device_count(ggml_backend_reg_t reg) {
+ ggml_backend_cuda_reg_context * ctx = (ggml_backend_cuda_reg_context *)reg->context;
+ return ctx->devices.size();
+}
+
+static ggml_backend_dev_t ggml_backend_cuda_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+ ggml_backend_cuda_reg_context * ctx = (ggml_backend_cuda_reg_context *)reg->context;
+ GGML_ASSERT(index < ctx->devices.size());
+ return ctx->devices[index];
+}
+
+static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t reg) {
+ static std::vector<ggml_backend_feature> features = []() {
+ std::vector<ggml_backend_feature> features;
+ #define _STRINGIFY(...) #__VA_ARGS__
+ #define STRINGIFY(...) _STRINGIFY(__VA_ARGS__)
+
+ #ifdef __CUDA_ARCH_LIST__
+ features.push_back({ "ARCHS", STRINGIFY(__CUDA_ARCH_LIST__) });
+ #endif
+
+ #ifdef GGML_CUDA_FORCE_MMQ
+ features.push_back({ "FORCE_MMQ", "1" });
+ #endif
+
+ #ifdef GGML_CUDA_FORCE_CUBLAS
+ features.push_back({ "FORCE_CUBLAS", "1" });
+ #endif
+
+ #ifndef GGML_USE_VMM
+ features.push_back({ "NO_VMM", "1" });
+ #endif
+
+ #ifdef GGML_CUDA_NO_PEER_COPY
+ features.push_back({ "NO_PEER_COPY", "1" });
+ #endif
+
+ #ifdef GGML_CUDA_USE_GRAPHS
+ features.push_back({ "USE_GRAPHS", "1" });
+ #endif
+
+ #ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE
+ features.push_back({ "PEER_MAX_BATCH_SIZE", STRINGIFY(GGML_CUDA_PEER_MAX_BATCH_SIZE) });
+ #endif
+
+ #ifdef GGML_CUDA_FA_ALL_QUANTS
+ features.push_back({ "FA_ALL_QUANTS", "1" });
+ #endif
+
+ {
+ const auto & info = ggml_cuda_info();
+ for (int id = 0; id < info.device_count; ++id) {
+ if (blackwell_mma_available(info.devices[id].cc)) {
+ features.push_back({ "BLACKWELL_NATIVE_FP4", "1"});
+ break;
+ }
+ }
+ }
+
+ #undef _STRINGIFY
+ #undef STRINGIFY
+
+ features.push_back({ nullptr, nullptr });
+
+ return features;
+ }();
+
+ return features.data();
+
+ GGML_UNUSED(reg);
+}
+
+static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+ GGML_UNUSED(reg);
+ if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
+ return (void *)ggml_backend_cuda_split_buffer_type;
+ }
+ if (strcmp(name, "ggml_backend_register_host_buffer") == 0) {
+ return (void *)ggml_backend_cuda_register_host_buffer;
+ }
+ if (strcmp(name, "ggml_backend_unregister_host_buffer") == 0) {
+ return (void *)ggml_backend_cuda_unregister_host_buffer;
+ }
+ if (strcmp(name, "ggml_backend_get_features") == 0) {
+ return (void *)ggml_backend_cuda_get_features;
+ }
+ return nullptr;
+}
+
+static const ggml_backend_reg_i ggml_backend_cuda_reg_interface = {
+ /* .get_name = */ ggml_backend_cuda_reg_get_name,
+ /* .get_device_count = */ ggml_backend_cuda_reg_get_device_count,
+ /* .get_device = */ ggml_backend_cuda_reg_get_device,
+ /* .get_proc_address = */ ggml_backend_cuda_reg_get_proc_address,
+};
+
+// backend registry
+ggml_backend_reg_t ggml_backend_cuda_reg() {
+ static ggml_backend_reg reg;
+ static bool initialized = false;
+
+ {
+ static std::mutex mutex;
+ std::lock_guard<std::mutex> lock(mutex);
+ if (!initialized) {
+ ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
+ const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
+
+ for (int i = 0; i < ggml_cuda_info().device_count; i++) {
+ ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
+ dev_ctx->device = i;
+ dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
+
+ cudaDeviceProp prop;
+ CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
+ dev_ctx->description = prop.name;
+
+ char pci_bus_id[16] = {};
+ snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
+ dev_ctx->pci_bus_id = pci_bus_id;
+ dev_ctx->op_offload_min_batch_size = min_batch_size;
+
+ ggml_backend_dev_t dev = new ggml_backend_device {
+ /* .iface = */ ggml_backend_cuda_device_interface,
+ /* .reg = */ &reg,
+ /* .context = */ dev_ctx
+ };
+ ctx->devices.push_back(dev);
+ }
+
+ reg = ggml_backend_reg {
+ /* .api_version = */ GGML_BACKEND_API_VERSION,
+ /* .iface = */ ggml_backend_cuda_reg_interface,
+ /* .context = */ ctx
+ };
+ }
+
+ initialized = true;
+ }
+
+ return &reg;
+}
+
+ggml_backend_t ggml_backend_cuda_init(int device) {
+ if (device < 0 || device >= ggml_backend_cuda_get_device_count()) {
+ GGML_LOG_ERROR("%s: invalid device %d\n", __func__, device);
+ return nullptr;
+ }
+
+ ggml_backend_cuda_context * ctx = new ggml_backend_cuda_context(device);
+ if (ctx == nullptr) {
+ GGML_LOG_ERROR("%s: failed to allocate context\n", __func__);
+ return nullptr;
+ }
+
+ ggml_backend_t cuda_backend = new ggml_backend {
+ /* .guid = */ ggml_backend_cuda_guid(),
+ /* .iface = */ ggml_backend_cuda_interface,
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
+ /* .context = */ ctx,
+ };
+
+ return cuda_backend;
+}
+
+GGML_BACKEND_DL_IMPL(ggml_backend_cuda_reg)
diff --git a/llama.cpp/ggml/src/ggml-cuda/gla.cu b/llama.cpp/ggml/src/ggml-cuda/gla.cu
new file mode 100644
index 0000000..f7d615a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/gla.cu
@@ -0,0 +1,93 @@
+#include "common.cuh"
+#include "gla.cuh"
+
+template<int HEAD_SIZE>
+static __global__ void gated_linear_attn_f32(const int B, const int T, const int C, const int H, const float scale,
+ const float * k, const float * v, const float * r, const float * td, const float * s, float * dst) {
+ const int tid = threadIdx.x;
+ const int bid = blockIdx.x;
+
+ const int head_size = HEAD_SIZE;
+ const int batch_i = bid / H;
+ const int head_i = bid % H;
+ const int state_size = C * head_size;
+ const int n_seq_tokens = T / B;
+
+ float state[head_size];
+ __shared__ float _k[head_size], _r[head_size], _td[head_size];
+
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+ }
+
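+    // Recurrence per token: S = diag(td) * S + k * v^T and out = (r^T S) * scale,
+    // where each thread owns one output channel (column tid of the state matrix).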
+ for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+ __syncthreads();
+ _k[tid] = k[t];
+ _r[tid] = r[t];
+ _td[tid] = td[t];
+ __syncthreads();
+
+ const float _v = v[t];
+ float y = 0;
+ for (int j = 0; j < head_size; j += 4) {
+ const float4 & k = (float4 &)(_k[j]);
+ const float4 & r = (float4 &)(_r[j]);
+ const float4 & td = (float4 &)(_td[j]);
+ float4 & s = (float4 &)(state[j]);
+ float4 kv;
+
+ kv.x = k.x * _v;
+ kv.y = k.y * _v;
+ kv.z = k.z * _v;
+ kv.w = k.w * _v;
+
+ s.x = s.x * td.x + kv.x;
+ s.y = s.y * td.y + kv.y;
+ s.z = s.z * td.z + kv.z;
+ s.w = s.w * td.w + kv.w;
+
+ y += r.x * s.x;
+ y += r.y * s.y;
+ y += r.z * s.z;
+ y += r.w * s.w;
+ }
+ dst[t] = y * scale;
+ }
+
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+ }
+}
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const float * k_d = (const float *)dst->src[0]->data;
+ const float * v_d = (const float *)dst->src[1]->data;
+ const float * r_d = (const float *)dst->src[2]->data;
+ const float * td_d = (const float *)dst->src[3]->data;
+ const float * s_d = (const float *)dst->src[4]->data;
+
+ const int64_t B = dst->src[4]->ne[1];
+ const int64_t T = dst->src[0]->ne[2];
+ const int64_t C = dst->ne[0];
+ const int64_t H = dst->src[0]->ne[1];
+
+ float scale;
+ memcpy(&scale, (float*)dst->op_params, sizeof(float));
+
+ float * dst_d = (float *)dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(dst->src[4]->type == GGML_TYPE_F32);
+ GGML_ASSERT(C % H == 0);
+ GGML_ASSERT(C / H == 64 || C / H == 128);
+
+ if (C / H == 64) {
+ gated_linear_attn_f32<64><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+ } else {
+ gated_linear_attn_f32<128><<<B * H, C / H, 0, stream>>>(B, T, C, H, scale, k_d, v_d, r_d, td_d, s_d, dst_d);
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/gla.cuh b/llama.cpp/ggml/src/ggml-cuda/gla.cuh
new file mode 100644
index 0000000..2c82ad7
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/gla.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_gated_linear_attn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/im2col.cu b/llama.cpp/ggml/src/ggml-cuda/im2col.cu
new file mode 100644
index 0000000..56dc054
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/im2col.cu
@@ -0,0 +1,264 @@
+#include "im2col.cuh"
+
+#define MAX_GRIDDIM_Z 65535
+
+template <typename T>
+static __global__ void im2col_kernel(
+ const float * x, T * dst,
+ int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH,
+ int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW,
+ int s0, int s1, int p0, int p1, int d0, int d1) {
+ const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
+ if (i >= IC_KH_KW) {
+ return;
+ }
+
+ const int64_t iic = i / (KH_KW);
+ const int64_t rem = i - iic * KH_KW;
+ const int64_t ikh = rem / KW;
+ const int64_t ikw = rem - ikh * KW;
+
+ const int64_t iow = blockIdx.y;
+ for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) {
+ const int64_t in = iz / OH;
+ const int64_t ioh = iz - in * OH;
+
+ const int64_t iiw = iow * s0 + ikw * d0 - p0;
+ const int64_t iih = ioh * s1 + ikh * d1 - p1;
+
+ const int64_t offset_dst =
+ ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw;
+
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+ dst[offset_dst] = 0.0f;
+ } else {
+ const int64_t offset_src = iic * IC_IH_IW + in * IH_IW;
+ dst[offset_dst] = x[offset_src + iih * IW + iiw];
+ }
+ }
+
+ GGML_UNUSED(IC);
+ GGML_UNUSED(KH);
+}
+
+// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
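+// OH/OW are taken from the dst tensor shape; for reference they follow the usual convolution
+// output-size formula, e.g. OW = (IW + 2*p0 - d0*(KW - 1) - 1) / s0 + 1.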
+template <typename T>
+static void im2col_cuda(const float * x, T* dst,
+ int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+ int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
+ int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+ const int64_t IC_KH_KW = IC * KH * KW;
+ const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+ const int64_t N_OH = N * OH;
+ const int64_t KH_KW = KW*KH;
+ dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z));
+ im2col_kernel<<<block_nums, MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH,
+ IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW,
+ s0, s1, p0, p1, d0, d1);
+}
+
+static void im2col_cuda_f16(const float * x, half * dst,
+ int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+ int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
+ int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+
+ im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
+}
+
+static void im2col_cuda_f32(const float * x, float * dst,
+ int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
+ int64_t N, int64_t IC_IH_IW, int64_t IH_IW,
+ int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+
+ im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
+}
+
+void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const float * src1_d = (const float *)src1->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+ const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+ const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+ const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+ const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+ const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+ const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+ const int64_t IC = src1->ne[is_2D ? 2 : 1];
+ const int64_t IH = is_2D ? src1->ne[1] : 1;
+ const int64_t IW = src1->ne[0];
+
+ const int64_t KH = is_2D ? src0->ne[1] : 1;
+ const int64_t KW = src0->ne[0];
+
+ const int64_t OH = is_2D ? dst->ne[2] : 1;
+ const int64_t OW = dst->ne[1];
+
+ const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+ const int64_t N = src1->ne[is_2D ? 3 : 2];
+ const int64_t IH_IW = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
+
+ if(dst->type == GGML_TYPE_F16) {
+ im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
+ } else {
+ im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
+ }
+}
+
+// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+template <typename T>
+static __global__ void im2col_3d_kernel(
+ const float * src, T * dst,
+ int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
+ int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+ int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,
+ int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,
+ int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,
+ int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
+ int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
+ const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
+ if (i >= IC_KD_KH_KW) {
+ return;
+ }
+ GGML_UNUSED(N); GGML_UNUSED(OC); GGML_UNUSED(OH_OW); GGML_UNUSED(OD); GGML_UNUSED(OW); GGML_UNUSED(KD); GGML_UNUSED(KH);
+ GGML_UNUSED(ID_IH_IW); GGML_UNUSED(IH_IW); GGML_UNUSED(IC_ID_IH_IW); GGML_UNUSED(OW_KD_KH_KW);
+
+ const int64_t iic = i / KD_KH_KW;
+ const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;
+ const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
+ const int64_t ikw = i % KW;
+
+ const int64_t iow = blockIdx.y;
+ for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) {
+ const int64_t in = iz / OD_OH;
+ const int64_t iod = (iz - in*OD_OH) / OH;
+ const int64_t ioh = iz % OH;
+
+ const int64_t iiw = iow * s0 + ikw * d0 - p0;
+ const int64_t iih = ioh * s1 + ikh * d1 - p1;
+ const int64_t iid = iod * s2 + ikd * d2 - p2;
+
+ const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
+
+ if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
+ dst[offset_dst] = 0.0f;
+ } else {
+ const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);
+ dst[offset_dst] = src[offset_src];
+ }
+ }
+}
+
+// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
+template <typename T>
+static void im2col_3d_cuda(const float * src, T* dst,
+ int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
+ int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+ int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
+ int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
+ const int64_t OH_OW = OH*OW;
+ const int64_t KD_KH_KW = KD*KH*KW;
+ const int64_t ID_IH_IW = ID*IH*IW;
+ const int64_t KH_KW = KH*KW;
+ const int64_t IH_IW = IH*IW;
+ const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
+ const int64_t OW_KD_KH_KW = OW*KD*KH*KW;
+ const int64_t N_OD_OH = N*OD*OH;
+ const int64_t OD_OH = OD*OH;
+ const int64_t IC_ID_IH_IW = IC*ID*IH*IW;
+ const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
+ const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
+ const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
+ const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+ dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z));
+ im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
+ OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
+ IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
+ OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,
+ stride_q, stride_z, stride_y, stride_x,
+ s0, s1, s2, p0, p1, p2, d0, d1, d2);
+}
+
+static void im2col_3d_cuda_f16(const float * src, half * dst,
+ int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
+ int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+ int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
+ int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
+
+ im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
+ stride_q, stride_z, stride_y, stride_x,
+ s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+}
+
+static void im2col_3d_cuda_f32(const float * src, float * dst,
+ int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
+ int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+ int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
+ int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
+
+ im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
+ stride_q, stride_z, stride_y, stride_x,
+ s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+}
+
+void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const float * src1_d = (const float *)src1->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
+ const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
+ const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
+ const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
+ const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
+ const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
+ const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
+ const int32_t IC = ((const int32_t *)(dst->op_params))[9];
+
+ const int64_t N = ne13 / IC;
+ const int64_t ID = ne12;
+ const int64_t IH = ne11;
+ const int64_t IW = ne10;
+
+ const int64_t OC = ne03 / IC;
+ const int64_t KD = ne02;
+ const int64_t KH = ne01;
+ const int64_t KW = ne00;
+
+ const int64_t OD = ne3 / N;
+ const int64_t OH = ne2;
+ const int64_t OW = ne1;
+
+ const size_t es = ggml_element_size(src1);
+ const int64_t stride_x = src1->nb[0] / es;
+ const int64_t stride_y = src1->nb[1] / es;
+ const int64_t stride_z = src1->nb[2] / es;
+ const int64_t stride_q = src1->nb[3] / es;
+
+ if(dst->type == GGML_TYPE_F16) {
+ im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
+ stride_q, stride_z, stride_y, stride_x,
+ s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+ } else {
+ im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
+ stride_q, stride_z, stride_y, stride_x,
+ s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/im2col.cuh b/llama.cpp/ggml/src/ggml-cuda/im2col.cuh
new file mode 100644
index 0000000..2da1223
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/im2col.cuh
@@ -0,0 +1,6 @@
+#include "common.cuh"
+
+#define CUDA_IM2COL_BLOCK_SIZE 256
+
+void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/mean.cu b/llama.cpp/ggml/src/ggml-cuda/mean.cu
new file mode 100644
index 0000000..49af538
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/mean.cu
@@ -0,0 +1,75 @@
+#include "mean.cuh"
+#include "reduce_rows.cuh"
+
+#ifdef GGML_CUDA_USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // GGML_CUDA_USE_CUB
+
+template <typename T> __global__ void divide_by_count(T * result, size_t count) {
+ *result /= static_cast<T>(count);
+}
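+// mean = sum / count: the CUB path below computes the row sum device-wide and then divides
+// by ncols with this single-thread kernel, avoiding a host synchronization.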
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *) src0->data;
+ float * dst_d = (float *) dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ const int64_t ncols = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+// Special case for reducing vectors
+#ifdef GGML_CUDA_USE_CUB
+#ifdef USE_CUDA_GRAPH
+ cudaStreamCaptureStatus iscapturing;
+ CUDA_CHECK(cudaStreamIsCapturing(stream, &iscapturing));
+#endif // USE_CUDA_GRAPH
+ if ((nrows == 1) &&
+#ifdef USE_CUDA_GRAPH
+ // Determine if CUDA graphs are effectively disabled for this context
+ // (no graph instance exists and we're not capturing, OR graphs are explicitly enabled)
+ (((ncols > 65536) &&
+ (((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) ||
+ ctx.any_cuda_graph_enabled())) ||
+ // CUDA graphs are enabled - use lower threshold
+ ((ncols > 32768) &&
+ !(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) ||
+ ctx.any_cuda_graph_enabled())))) {
+#else
+ (ncols > 65536)) {
+#endif // USE_CUDA_GRAPH
+ // Single row - use device-wide reduction
+ size_t tmp_size = 0;
+ ggml_cuda_pool & pool = ctx.pool();
+
+ DeviceReduce::Sum(nullptr, tmp_size, src0_d, dst_d, ncols, stream);
+
+ ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+ DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, src0_d, dst_d, ncols, stream);
+
+ // Divide by ncols
+ divide_by_count<float><<<1, 1, 0, stream>>>(dst_d, ncols);
+ return;
+ }
+#endif // GGML_CUDA_USE_CUB
+
+ const dim3 block_nums(nrows, 1, 1);
+
+ const int id = ggml_cuda_get_device();
+ const int nsm = ggml_cuda_info().devices[id].nsm;
+
+ // Heuristic for block size selection to optimize occupancy.
+ // See discussion in: https://github.com/ggml-org/llama.cpp/pull/15132
+ if ((nrows / nsm) < 2) {
+ const dim3 block_dims(512, 1, 1);
+ reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+ } else {
+ const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+ reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/mean.cuh b/llama.cpp/ggml/src/ggml-cuda/mean.cuh
new file mode 100644
index 0000000..2b9b104
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/mean.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/mma.cuh b/llama.cpp/ggml/src/ggml-cuda/mma.cuh
new file mode 100644
index 0000000..dd45d6c
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/mma.cuh
@@ -0,0 +1,1381 @@
+#pragma once
+// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
+// The primitives can be used in a similar way to the nvcuda::wmma interface, but with a well-defined memory layout.
+// The documentation for the PTX instructions can be found under:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
+//
+// As with nvcuda::wmma there are three types of matrix tiles: A, B, and C, with A @ B = C.
+// A is a row-major matrix with shape M x K.
+// B is a column-major matrix with shape K x N.
+// C is a column-major matrix with shape M x N.
+// A, B, and C are represented using the same fundamental data type: a row-major matrix with I rows and J columns.
+// Note that J is measured in physical 32 bit elements instead of logical elements.
+// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
+// All matrix tiles have ne physical 32 bit elements per warp.
+//
+// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
+// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes; passing unaligned pointers is undefined behavior.
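+//
+// As an illustrative sketch only (not part of the interface), the element-wise fallback in
+// load_generic further below boils down to:
+//     for (int l = 0; l < t.ne; ++l) {
+//         t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+//     }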
+
+#include "common.cuh"
+
+// On Volta each warp is doing 4 8x8 mma operations in parallel.
+// The basic memory layout for a 32x8 output tile is to stack 4 input tiles in the I direction and to mirror the B tile.
+// However, the i indices in this file are by default permuted to simplify the index calculations.
+// #define GGML_CUDA_MMA_NO_VOLTA_PERM
+
+#if CUDART_VERSION >= 11080
+
+static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
+ int ret = 0;
+
+#ifdef TURING_MMA_AVAILABLE
+ asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
+ : "=r"(ret) : "r"(x));
+#else
+ GGML_UNUSED(x);
+ NO_DEVICE_CODE;
+#endif // defined(TURING_MMA_AVAILABLE)
+ return ret;
+}
+
+#else
+
+static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
+    // Imagine transposing a row-major matrix into a column-major matrix.
+ const int src_i_low = 2 * (threadIdx.x % 4);
+ const int src_i_high = src_i_low + 1;
+ const int src_j = threadIdx.x / 4;
+
+ const int src_laneid_low = src_i_low * 4 + src_j / 2;
+ const int src_laneid_high = src_i_high * 4 + src_j / 2;
+
+ const int shift_low = ((src_j + 0) % 2) * 16;
+ const int shift_high = ((src_j + 1) % 2) * 16;
+
+ const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF;
+ const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
+
+ return ret_low | ret_high;
+}
+
+#endif // CUDART_VERSION >= 11080
+
+static __device__ __forceinline__ half2 ggml_cuda_movmatrix(const half2 x) {
+ half2 ret;
+ *((int *) &ret) = ggml_cuda_movmatrix(*((const int *) &x));
+ return ret;
+}
+
+namespace ggml_cuda_mma {
+
+ // Some architectures like Volta or CDNA3 perform multiple matrix multiplications per warp in parallel,
+ // effectively the warp is being split into subgroups of threads that each perform a single mma instruction.
+ // In those cases the data can be split in different ways across the warp.
+ enum data_layout {
+ // By default the data uses the I direction as its major dimension and the J direction as its minor dimension.
+ // For the A/C matrices this means I major == row major, J major == column major.
+ // For the B matrix this means I major == column major, J major == row major.
+ // MIRRORED == Each data value is held exactly once per thread subgroup.
+ DATA_LAYOUT_I_MAJOR = 0, // Always used for Turing, Ampere, Ada Lovelace, consumer Blackwell, matrix A&B for RDNA4 and CDNA.
+ DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3.
+ DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3.
+ DATA_LAYOUT_J_MAJOR_MIRRORED = 30,
+ };
+ // Implemented mma combinations are:
+ // - (I_MAJOR, I_MAJOR) -> I_MAJOR
+ // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR
+ // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR
+
+ static constexpr bool is_i_major(const data_layout dl) {
+ return dl == DATA_LAYOUT_I_MAJOR ||
+ dl == DATA_LAYOUT_I_MAJOR_MIRRORED;
+ }
+
+ static constexpr __device__ data_layout get_input_data_layout() {
+#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ return DATA_LAYOUT_I_MAJOR_MIRRORED;
+#else
+ return DATA_LAYOUT_I_MAJOR;
+#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ }
+
+ template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR>
+ struct tile {};
+
+ template <int I_, int J_, typename T>
+ struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR> {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
+
+#if defined(AMD_MFMA_AVAILABLE)
+ static constexpr int ne = I * J / 64;
+ T x[ne] = {0};
+
+ static constexpr __device__ bool supported() {
+ if (I == 64 && J == 2) return true;
+ if (I == 16 && J == 8) return true;
+ if (I == 32 && J == 4) return true;
+ if (I == 16 && J == 16) return true;
+ if (I == 32 && J == 32) return true;
+ return false;
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+ return threadIdx.x % 16;
+ } else if constexpr (I == 16 && J == 8) {
+ return threadIdx.x % 16;
+ } else if constexpr (I == 32 && J == 4) {
+ return threadIdx.x % 32;
+ } else if constexpr (I == 16 && J == 16) {
+ return threadIdx.x % 16;
+ } else if constexpr (I == 32 && J == 32) {
+ return threadIdx.x % 32;
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+ return (2 * ((threadIdx.x / 16) % 2) + l);
+ } else if constexpr (I == 16 && J == 8) {
+ return 2 * (threadIdx.x / 16) + l;
+ } else if constexpr (I == 32 && J == 4) {
+ return 2 * (threadIdx.x / 32) + l;
+ } else if constexpr (I == 16 && J == 16) {
+ return 4 * (threadIdx.x / 16) + l;
+ } else if constexpr (I == 32 && J == 32) {
+ return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ static constexpr int ne = I * J / 32;
+ T x[ne] = {0};
+
+ static constexpr __device__ bool supported() {
+ if (I == 32 && J == 8) return true;
+ return false;
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ if constexpr (I == 32 && J == 8) {
+#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+ return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (l & 2) + (threadIdx.x % 2);
+#else
+ return (l & 2) + (threadIdx.x & ~2);
+#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (I == 32 && J == 8) {
+ return (threadIdx.x & 2) + (l & (4 + 1));
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+#elif defined(AMD_WMMA_AVAILABLE)
+ static constexpr int ne = I * J / 32;
+ T x[ne] = {0};
+
+ static constexpr __device__ bool supported() {
+ if (I == 16 && J == 16) return true;
+ if (I == 16 && J == 8) return true;
+ if (I == 16 && J == 4) return true;
+ return false;
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ if constexpr (supported()) {
+ return threadIdx.x % 16;
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (I == 16 && J == 16) {
+#if defined(RDNA3)
+ if constexpr (std::is_same_v<T, float> || std::is_same_v<T, int>) {
+ // matrix C
+ return 2 * l + (threadIdx.x / 16);
+ } else {
+ // matrix A&B
+ return l;
+ }
+#else
+ // matrix C is the transposed matrix A&B on RDNA4
+ return ne * (threadIdx.x / 16) + l;
+#endif // defined(RDNA3)
+ } else if constexpr (I == 16 && J == 8) {
+ // mmq input for RDNA4
+ return ne * (threadIdx.x / 16) + l;
+ } else if constexpr (I == 16 && J == 4) {
+ return ne * (threadIdx.x / 16) + l;
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+#else
+ static constexpr int ne = I * J / 32;
+ T x[ne] = {0};
+
+ static constexpr __device__ bool supported() {
+ if (I == 8 && J == 4) return true;
+ if (I == 8 && J == 8) return true;
+ if (I == 16 && J == 8) return true;
+ if (I == 16 && J == 16) return true;
+ if (I == 32 && J == 8) return true;
+ return false;
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ if constexpr (I == 8 && J == 4) {
+ return threadIdx.x / 4;
+ } else if constexpr (I == 8 && J == 8) {
+ return threadIdx.x / 4;
+ } else if constexpr (I == 16 && J == 8) {
+ return ((l / 2) * 8) + (threadIdx.x / 4);
+ } else if constexpr (I == 16 && J == 16) {
+ return (((l / 2) % 2) * 8) + (threadIdx.x / 4);
+ } else if constexpr (I == 32 && J == 8) {
+ return tile<16, 8, T>::get_i(l); // Memory layout simply repeated with same pattern in i direction.
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (I == 8 && J == 4) {
+ return threadIdx.x % 4;
+ } else if constexpr (I == 8 && J == 8) {
+ return (l * 4) + (threadIdx.x % 4);
+ } else if constexpr (I == 16 && J == 8) {
+ return ((threadIdx.x % 4) * 2) + (l % 2);
+ } else if constexpr (I == 16 && J == 16) {
+ return ((l / 4) * 8) + ((threadIdx.x % 4) * 2) + (l % 2);
+ } else if constexpr (I == 32 && J == 8) {
+ return tile<16, 8, T>::get_j(l); // Memory layout simply repeated with same pattern in i direction.
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+#endif // defined(GGML_USE_HIP)
+ };
+
+ template <int I_, int J_>
+ struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR> {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
+
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ static constexpr int ne = I * J / WARP_SIZE;
+ half2 x[ne] = {{0.0f, 0.0f}};
+
+ static constexpr __device__ bool supported() {
+ if (I == 32 && J == 4) return true;
+ return false;
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ if constexpr (I == 32 && J == 4) {
+#ifdef GGML_CUDA_MMA_NO_VOLTA_PERM
+ return (((threadIdx.x % 16) / 4) * 8) + ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
+#else
+ return threadIdx.x;
+#endif // GGML_CUDA_MMA_NO_VOLTA_PERM
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (I == 32 && J == 4) {
+ return l;
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+#elif defined(AMD_WMMA_AVAILABLE)
+ static constexpr int ne = I * J / 32;
+ half2 x[ne] = {{0.0f, 0.0f}};
+
+ static constexpr __device__ bool supported() {
+ if (I == 16 && J == 8) return true;
+ return false;
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ if constexpr (I == 16 && J == 8) {
+ return threadIdx.x % 16;
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (I == 16 && J == 8) {
+ return ne * (threadIdx.x / 16) + l;
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+#elif defined(AMD_MFMA_AVAILABLE)
+ static constexpr int ne = I * J / 64;
+ half2 x[ne] = {{0.0f, 0.0f}};
+
+ static constexpr __device__ bool supported() {
+ if (I == 16 && J == 8) return true;
+ return false;
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ if constexpr (I == 16 && J == 8) {
+ return threadIdx.x % 16;
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (I == 16 && J == 8) {
+ return ne * (threadIdx.x / 16) + l;
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+#else
+ static constexpr int ne = I * J / WARP_SIZE;
+ half2 x[ne] = {{0.0f, 0.0f}};
+
+ static constexpr __device__ bool supported() {
+ if (I == 8 && J == 4) return true;
+ if (I == 8 && J == 8) return true;
+ if (I == 16 && J == 8) return true;
+ if (I == 16 && J == 16) return true;
+ if (I == 32 && J == 8) return true;
+ return false;
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ if constexpr (I == 8 && J == 8) {
+ return threadIdx.x / 4;
+ } else if constexpr (I == 16 && J == 4) {
+ return (l * 8) + (threadIdx.x / 4);
+ } else if constexpr (I == 16 && J == 8) {
+ return ((l % 2) * 8) + (threadIdx.x / 4);
+ } else if constexpr (I == 32 && J == 8) {
+ return ((l / 4) * 16) + ((l % 2) * 8) + (threadIdx.x / 4);
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (I == 8 && J == 8) {
+ return (l * 4) + (threadIdx.x % 4);
+ } else if constexpr (I == 16 && J == 4) {
+ return threadIdx.x % 4;
+ } else if constexpr (I == 16 && J == 8) {
+ return ((l / 2) * 4) + (threadIdx.x % 4);
+ } else if constexpr (I == 32 && J == 8) {
+ return ((l & 2) * 2) + (threadIdx.x % 4);
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ };
+
+ template <int I_, int J_>
+ struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR> {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
+
+#if defined(AMD_WMMA_AVAILABLE)
+ static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
+ nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+ static constexpr __device__ bool supported() {
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
+ }
+#elif defined(AMD_MFMA_AVAILABLE)
+ static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
+ nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+ static constexpr __device__ bool supported() {
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
+ }
+#else
+ static constexpr int ne = I * J / WARP_SIZE;
+ nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+ static constexpr __device__ bool supported() {
+ if (I == 8 && J == 8) return true;
+ if (I == 16 && J == 4) return true;
+ if (I == 16 && J == 8) return true;
+ return false;
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ if constexpr (I == 8 && J == 8) {
+ return threadIdx.x / 4;
+ } else if constexpr (I == 16 && J == 4) {
+ return (l * 8) + (threadIdx.x / 4);
+ } else if constexpr (I == 16 && J == 8) {
+ return ((l % 2) * 8) + (threadIdx.x / 4);
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (I == 8 && J == 8) {
+ return (l * 4) + (threadIdx.x % 4);
+ } else if constexpr (I == 16 && J == 4) {
+ return threadIdx.x % 4;
+ } else if constexpr (I == 16 && J == 8) {
+ return ((l / 2) * 4) + (threadIdx.x % 4);
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+#endif // defined(AMD_WMMA_AVAILABLE)
+ };
+
+ template <int I_, int J_, typename T>
+ struct tile<I_, J_, T, DATA_LAYOUT_J_MAJOR> {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR;
+
+ static constexpr int ne = tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::ne;
+ T x[ne] = {0};
+
+ static constexpr __device__ bool supported() {
+ return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::supported();
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_j(l);
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ return tile<I_, J_, T, DATA_LAYOUT_I_MAJOR>::get_i(l);
+ }
+ };
+
+ template <int I_, int J_, typename T>
+ struct tile<I_, J_, T, DATA_LAYOUT_I_MAJOR_MIRRORED> {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+
+ // RDNA3
+ static constexpr int ne = I * J / 32 * 2;
+
+ T x[ne] = {0};
+
+ static constexpr __device__ bool supported() {
+ if (I == 16 && J == 16) return true;
+ if (I == 16 && J == 8) return true;
+ if (I == 16 && J == 4) return true;
+ return false;
+ }
+
+ static __device__ __forceinline__ int get_i(const int /*l*/) {
+ if constexpr (supported()) {
+ return threadIdx.x % 16;
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (supported()) {
+ return l;
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+ };
+
+ template <int I_, int J_>
+ struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+#if defined(RDNA3)
+ static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
+
+ half2 x[ne] = {{0.0f, 0.0f}};
+
+ static constexpr __device__ bool supported() {
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
+ }
+#else // Volta
+ static constexpr int ne = I * J / (WARP_SIZE/4);
+
+ half2 x[ne] = {{0.0f, 0.0f}};
+
+ static constexpr __device__ bool supported() {
+ if (I == 8 && J == 4) return true;
+ return false;
+ }
+
+ static __device__ __forceinline__ int get_i(const int /*l*/) {
+ if constexpr (I == 8 && J == 4) {
+ return ((threadIdx.x / 16) * 4) + (threadIdx.x % 4);
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (I == 8 && J == 4) {
+ return l;
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+#endif // defined(RDNA3)
+ };
+
+ template <int I_, int J_>
+ struct tile<I_, J_, nv_bfloat162, DATA_LAYOUT_I_MAJOR_MIRRORED> {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_MIRRORED;
+ static constexpr int ne = tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::ne;
+
+ nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
+
+ static constexpr __device__ bool supported() {
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::supported();
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_i(l);
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ return tile<I_, J_, float, DATA_LAYOUT_I_MAJOR_MIRRORED>::get_j(l);
+ }
+ };
+
+ template <int I_, int J_>
+ struct tile<I_, J_, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> {
+ static constexpr int I = I_;
+ static constexpr int J = J_;
+ static constexpr data_layout dl = DATA_LAYOUT_J_MAJOR_MIRRORED;
+ static constexpr int ne = I * J / (WARP_SIZE/4);
+
+ half2 x[ne] = {{0.0f, 0.0f}};
+
+ static constexpr __device__ bool supported() {
+ if (I == 8 && J == 4) return true;
+ return false;
+ }
+
+ static __device__ __forceinline__ int get_i(const int l) {
+ if constexpr (I == 8 && J == 4) {
+ return ((l / 2) * 4) + (threadIdx.x % 4);
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+
+ static __device__ __forceinline__ int get_j(const int l) {
+ if constexpr (I == 8 && J == 4) {
+ return ((threadIdx.x / 16) * 2) + (l % 2);
+ } else {
+ NO_DEVICE_CODE;
+ return -1;
+ }
+ }
+ };
+
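+ // Convert an FP32 tile into a half2 tile by packing adjacent value pairs; on Volta an extra
+ // shuffle accounts for the differing FP16/FP32 tile layouts (see comment below).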
+#if defined(TURING_MMA_AVAILABLE)
+ template <int I, int J>
+ static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
+ tile<I, J/2, half2> ret;
+#pragma unroll
+ for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
+ ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
+ }
+ return ret;
+ }
+
+ static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
+ tile<8, 8, half2> ret;
+ ret.x[0] = ggml_cuda_movmatrix(t.x[0]);
+ ret.x[1] = ggml_cuda_movmatrix(t.x[1]);
+
+ return ret;
+ }
+#elif defined(AMD_WMMA_AVAILABLE)
+ template <int I, int J>
+ static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
+ tile<I, J/2, half2> ret;
+#pragma unroll
+ for (int l0 = 0; l0 < tile_float.ne; l0 += 2) {
+ ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
+ }
+ return ret;
+ }
+
+ static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) {
+ NO_DEVICE_CODE;
+ return tile<8, 8, half2>{};
+ }
+#else // Volta
+ template <int I, int J>
+ static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) {
+ tile<I, J/2, half2> ret;
+#pragma unroll
+ for (int l0 = 0; l0 < tile_float.ne; l0 += 4) {
+ ret.x[l0/2 + 0] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]);
+ ret.x[l0/2 + 1] = make_half2(tile_float.x[l0 + 2], tile_float.x[l0 + 3]);
+
+ // On Volta FP16 and FP32 tiles have different memory layouts;
+ // for the conversion, threads with an offset of 2 need to exchange half of their values:
+ ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)] = __shfl_xor_sync(
+ 0xFFFFFFFF, ret.x[l0/2 + (((threadIdx.x % 4) / 2) ^ 1)], 2, WARP_SIZE);
+ }
+ return ret;
+ }
+#endif // defined(TURING_MMA_AVAILABLE)
+
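+ // Build an identity tile on RDNA4: each lane writes 1.0f or 0.0f into the single element it
+ // owns, depending on whether that element lies on the main diagonal; other architectures are
+ // not implemented (NO_DEVICE_CODE).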
+ static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) {
+#if defined(RDNA4)
+ const int row = t.get_i(0);
+ const int left_right = t.get_j(0) / 4;
+ const int up_down = row / 8;
+ const int idx = row % 8;
+ reinterpret_cast<half*>(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f;
+#else
+ GGML_UNUSED_VARS(t);
+ NO_DEVICE_CODE;
+#endif // defined(RDNA4)
+ }
+
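+ // Load a tile from memory with row stride `stride`. The generic path reads one element per
+ // get_i()/get_j() pair; on AMD MFMA/WMMA the i-major layouts are contiguous per thread, so
+ // vectorized ggml_cuda_memcpy_1 copies are used where possible.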
+ template <int I, int J, typename T, data_layout dl>
+ static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) {
+#if defined(AMD_MFMA_AVAILABLE)
+ if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
+#pragma unroll
+ for (int l = 0; l < t.ne; ++l) {
+ t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+ }
+ } else {
+ ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
+ }
+#elif defined(AMD_WMMA_AVAILABLE)
+ // All WMMA layouts have contiguous data when i-major.
+ if constexpr (is_i_major(dl)) {
+ // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes()
+ constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes();
+ if constexpr (sizeof(t.x) > aligned_copy_bytes) {
+ static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size");
+ constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes;
+#pragma unroll
+ for (int i = 0; i < aligned_copy_count; ++i) {
+ ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i));
+ }
+ } else {
+ ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0));
+ }
+ } else {
+#pragma unroll
+ for (int l = 0; l < t.ne; ++l) {
+ t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+ }
+ }
+#else
+#pragma unroll
+ for (int l = 0; l < t.ne; ++l) {
+ t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
+ }
+#endif // defined(AMD_MFMA_AVAILABLE)
+ }
+
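+ // Load an 8x8 tile; on Turing+ this maps to the PTX ldmatrix instruction, otherwise it
+ // falls back to load_generic.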
+ template <typename T>
+ static __device__ __forceinline__ void load_ldmatrix(
+ tile<8, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+#ifdef TURING_MMA_AVAILABLE
+ int * xi = (int *) t.x;
+ const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + ((threadIdx.x / t.I) * (t.J / 2)) % t.J;
+ asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+ : "=r"(xi[0]), "=r"(xi[1])
+ : "l"(xs));
+#else
+ load_generic(t, xs0, stride);
+#endif // TURING_MMA_AVAILABLE
+ }
+
+ template <typename T>
+ static __device__ __forceinline__ void load_ldmatrix(
+ tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) {
+#ifdef TURING_MMA_AVAILABLE
+ int * xi = (int *) t.x;
+ const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride;
+ asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
+ : "=r"(xi[0]), "=r"(xi[1])
+ : "l"(xs));
+#else
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ GGML_UNUSED_VARS(t, xs0, stride);
+ NO_DEVICE_CODE;
+#else
+ load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // TURING_MMA_AVAILABLE
+ }
+
+ template <typename T, data_layout dl>
+ static __device__ __forceinline__ void load_ldmatrix(
+ tile<16, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) {
+#if defined(TURING_MMA_AVAILABLE)
+ int * xi = (int * ) t.x;
+ const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
+ asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
+ : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3])
+ : "l"(xs));
+#else
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if 1
+ // TODO: more generic handling
+ static_assert(sizeof(T) == 4, "bad type size");
+ ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0);
+ ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4);
+#else
+ load_generic(t, xs0, stride);
+#endif // 1
+#else
+ load_generic(t, xs0, stride);
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // TURING_MMA_AVAILABLE
+ }
+
+ static __device__ __forceinline__ void load_ldmatrix(
+ tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
+ ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
+ }
+
+ static __device__ __forceinline__ void load_ldmatrix(
+ tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & t, const half2 * __restrict__ xs0, const int stride) {
+#pragma unroll
+ for (int l0 = 0; l0 < t.ne; l0 += 2) {
+ ggml_cuda_memcpy_1<2*sizeof(half2)>(t.x + l0, xs0 + t.get_i(l0)*stride + t.get_j(l0));
+ }
+ }
+
+ static __device__ __forceinline__ void load_ldmatrix(
+ tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride);
+#else
+ GGML_UNUSED_VARS(t, xs0, stride);
+ NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ }
+
+ template <typename T>
+ static __device__ __forceinline__ void load_ldmatrix_trans(
+ tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
+#ifdef TURING_MMA_AVAILABLE
+ int * xi = (int * ) t.x;
+ const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
+ asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
+ : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3])
+ : "l"(xs));
+#else
+ GGML_UNUSED_VARS(t, xs0, stride);
+ NO_DEVICE_CODE;
+#endif // TURING_MMA_AVAILABLE
+ }
+
+ static __device__ __forceinline__ void mma(
+ tile<16, 8, int> & D, const tile<16, 4, int> & A, const tile<8, 4, int> & B) {
+#ifdef TURING_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+ asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+ : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
+ : "r"(A.x[0]), "r"(A.x[1]), "r"(B.x[0]));
+#else
+ // On Turing m16n8k16 mma is not available, use 2x m8n8k16 mma instead:
+ asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+ : "+r"(D.x[0]), "+r"(D.x[1])
+ : "r"(A.x[0]), "r"(B.x[0]));
+ asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+ : "+r"(D.x[2]), "+r"(D.x[3])
+ : "r"(A.x[1]), "r"(B.x[0]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // TURING_MMA_AVAILABLE
+ }
+
+ static __device__ __forceinline__ void mma(
+ tile<16, 8, int> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B) {
+#ifdef TURING_MMA_AVAILABLE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+ asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+ : "+r"(D.x[0]), "+r"(D.x[1]), "+r"(D.x[2]), "+r"(D.x[3])
+ : "r"(A.x[0]), "r"(A.x[1]), "r"(A.x[2]), "r"(A.x[3]), "r"(B.x[0]), "r"(B.x[1]));
+#else
+ // On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
+ asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+ : "+r"(D.x[0]), "+r"(D.x[1])
+ : "r"(A.x[0]), "r"(B.x[0]));
+ asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+ : "+r"(D.x[2]), "+r"(D.x[3])
+ : "r"(A.x[1]), "r"(B.x[0]));
+ asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+ : "+r"(D.x[0]), "+r"(D.x[1])
+ : "r"(A.x[2]), "r"(B.x[1]));
+ asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
+ : "+r"(D.x[2]), "+r"(D.x[3])
+ : "r"(A.x[3]), "r"(B.x[1]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // TURING_MMA_AVAILABLE
+ }
+
+ static __device__ __forceinline__ void mma(
+ tile<16, 4, half2> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
+#ifdef TURING_MMA_AVAILABLE
+ const int * Axi = (const int *) A.x;
+ const int * Bxi = (const int *) B.x;
+ int * Dxi = (int *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+ asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+ // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
+ asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+ asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1])
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // TURING_MMA_AVAILABLE
+ }
+
+ static __device__ __forceinline__ void mma(
+ tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+#ifdef TURING_MMA_AVAILABLE
+ const int * Axi = (const int *) A.x;
+ const int * Bxi = (const int *) B.x;
+ int * Dxi = (int *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+ asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
+ asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+ : "+r"(Dxi[2]), "+r"(Dxi[3])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
+#else
+ // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
+ asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+ asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1])
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
+ asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+ : "+r"(Dxi[2]), "+r"(Dxi[3])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
+ asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+ : "+r"(Dxi[2]), "+r"(Dxi[3])
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+ using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
+ halfx8_t& acc_frag = reinterpret_cast<halfx8_t&>(D.x[0]);
+ const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
+ const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
+ acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // defined(RDNA4)
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // TURING_MMA_AVAILABLE
+ }
+
+ template <data_layout dl_ab, data_layout dl_d>
+ static __device__ __forceinline__ void mma(
+ tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+ const int * Axi = (const int *) A.x;
+ const int * Bxi = (const int *) B.x;
+ int * Dxi = (int *) D.x;
+ asm("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE
+ }
+
+ template <data_layout dl_ab, data_layout dl_d>
+ static __device__ __forceinline__ void mma(
+ tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
+#ifdef AMD_MFMA_AVAILABLE
+ using floatx4_t = __attribute__((ext_vector_type(4))) float;
+ floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
+#if defined(CDNA3)
+ using floatx2_t = __attribute__((ext_vector_type(2))) float;
+ const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
+ const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
+ acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
+#elif defined(CDNA2) || defined(CDNA1)
+#pragma unroll
+ for (int i = 0; i < 2; ++i) {
+ acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
+ }
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // defined(CDNA3)
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
+ }
+
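+ // Block-scaled MMA (m16n8k64) on e2m1 (MXFP4) inputs with ue8m0 block scales;
+ // Blackwell only, does nothing on other architectures.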
+ static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
+ const tile<16, 8, int> & A,
+ const tile<8, 8, int> & B,
+ uint32_t a_scale,
+ uint32_t b_scale) {
+#ifdef BLACKWELL_MMA_AVAILABLE
+ const int * Axi = (const int *) A.x;
+ const int * Bxi = (const int *) B.x;
+ float * Dxi = (float *) D.x;
+
+ asm volatile(
+ "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
+ "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
+ "%10, {0, 0}, %11, {0, 0};"
+ : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
+#else
+ GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
+#endif // BLACKWELL_MMA_AVAILABLE
+ }
+
+ static __device__ __forceinline__ void mma(
+ tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
+#ifdef TURING_MMA_AVAILABLE
+ const int * Axi = (const int *) A.x;
+ const int * Bxi = (const int *) B.x;
+ int * Dxi = (int *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+ asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+ // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
+ asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+ asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // TURING_MMA_AVAILABLE
+ }
+
+ static __device__ __forceinline__ void mma(
+ tile<16, 8, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<8, 8, nv_bfloat162> & B) {
+#ifdef AMPERE_MMA_AVAILABLE
+ const int * Axi = (const int *) A.x;
+ const int * Bxi = (const int *) B.x;
+ int * Dxi = (int *) D.x;
+ asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // AMPERE_MMA_AVAILABLE
+ }
+
+ template <data_layout dl_ab, data_layout dl_d>
+ static __device__ __forceinline__ void mma(
+ tile<16, 16, float, dl_d> & D, const tile<16, 8, half2, dl_ab> & A, const tile<16, 8, half2, dl_ab> & B) {
+#ifdef TURING_MMA_AVAILABLE
+ const int * Axi = (const int *) A.x;
+ const int * Bxi = (const int *) B.x;
+ int * Dxi = (int *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+ asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
+ asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+ : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
+#else
+ // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
+ asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+ asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
+ asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+ : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
+ asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+ : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+ using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
+ using floatx8_t = __attribute__((ext_vector_type(8))) float;
+ floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+ const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
+ const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
+#elif defined(RDNA3)
+ using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
+ using floatx8_t = __attribute__((ext_vector_type(8))) float;
+ floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+ const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]);
+ const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]);
+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag);
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // defined(RDNA4)
+#elif defined(AMD_MFMA_AVAILABLE)
+ using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
+ using floatx4_t = __attribute__((ext_vector_type(4))) float;
+ floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
+ const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
+ const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
+ acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // TURING_MMA_AVAILABLE
+ }
+
+ template <data_layout dl_ab, data_layout dl_d>
+ static __device__ __forceinline__ void mma(
+ tile<16, 16, float, dl_d> & D, const tile<16, 8, nv_bfloat162, dl_ab> & A, const tile<16, 8, nv_bfloat162, dl_ab> & B) {
+#if defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
+ using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
+ using floatx8_t = __attribute__((ext_vector_type(8))) float;
+ floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+ const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
+ const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
+#elif defined(RDNA3)
+ using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16;
+ using floatx8_t = __attribute__((ext_vector_type(8))) float;
+ floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+ const bf16x16_t& a_frag = reinterpret_cast<const bf16x16_t&>(A.x[0]);
+ const bf16x16_t& b_frag = reinterpret_cast<const bf16x16_t&>(B.x[0]);
+ acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag);
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // defined(RDNA4)
+#elif defined(AMD_MFMA_AVAILABLE)
+ using floatx4_t = __attribute__((ext_vector_type(4))) float;
+ floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
+#if defined(CDNA3) || defined(CDNA2)
+ using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
+ const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
+ const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
+ acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
+#elif defined(CDNA1)
+#pragma unroll
+ for (int i = 0; i < 2; ++i) {
+ using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;
+ const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]);
+ const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]);
+ acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);
+ }
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // defined(CDNA3) || defined(CDNA2)
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // defined(AMD_WMMA_AVAILABLE)
+ }
+
+ template <data_layout dl_d, data_layout dl_ab>
+ static __device__ __forceinline__ void mma(
+ tile<16, 16, int, dl_d> & D, const tile<16, 8, int, dl_ab> & A, const tile<16, 8, int, dl_ab> & B) {
+#if defined(AMD_MFMA_AVAILABLE)
+ using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+ int32x4_t * acc = (int32x4_t *) D.x;
+#if defined(CDNA3)
+ acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
+ ((int64_t *) B.x)[0],
+ acc[0],
+ 0, 0, 0);
+#elif defined(CDNA2) || defined(CDNA)
+ acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
+ B.x[0],
+ acc[0],
+ 0, 0, 0);
+ acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
+ B.x[1],
+ acc[0],
+ 0, 0, 0);
+#endif // defined(CDNA3)
+
+#elif defined(AMD_WMMA_AVAILABLE)
+
+ using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
+ int32x8_t * acc = (int32x8_t *) D.x;
+
+#if defined(RDNA4)
+ using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+ int32x2_t * a_vec = (int32x2_t *) A.x;
+ int32x2_t * b_vec = (int32x2_t *) B.x;
+
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+ true,
+ a_vec[0],
+ true,
+ b_vec[0],
+ acc[0],
+ true
+ );
+
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+ true,
+ a_vec[1],
+ true,
+ b_vec[1],
+ acc[0],
+ true
+ );
+
+#elif defined(RDNA3)
+ using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+ int32x4_t * a_vec = (int32x4_t *) A.x;
+ int32x4_t * b_vec = (int32x4_t *) B.x;
+
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
+ true,
+ a_vec[0],
+ true,
+ b_vec[0],
+ acc[0],
+ true
+ );
+
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
+ true,
+ a_vec[1],
+ true,
+ b_vec[1],
+ acc[0],
+ true
+ );
+#endif // defined(RDNA4)
+
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
+ }
+
+ static __device__ __forceinline__ void mma(
+ tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
+#if defined(AMD_MFMA_AVAILABLE)
+ using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
+ int32x16_t * acc = (int32x16_t *) D.x;
+#if defined(CDNA3)
+ acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
+ ((int64_t *) B.x)[0],
+ acc[0],
+ 0, 0, 0);
+#elif defined(CDNA2) || defined(CDNA)
+ acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
+ B.x[0],
+ acc[0],
+ 0, 0, 0);
+ acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
+ B.x[1],
+ acc[0],
+ 0, 0, 0);
+#endif // defined(CDNA3)
+
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE
+ }
+
+ template <typename T1, typename T2, int J, int K>
+ static __device__ __forceinline__ void mma(
+ tile<32, J, T1> & D, const tile<32, K, T2> & A, const tile<J, K, T2> & B) {
+ tile <16, J, T1> * D16 = reinterpret_cast< tile<16, J, T1> *>(&D);
+ const tile<16, K, T2> * A16 = reinterpret_cast<const tile<16, K, T2> *>(&A);
+ mma(D16[0], A16[0], B);
+ mma(D16[1], A16[1], B);
+ }
+
+ static __device__ __forceinline__ void mma(
+ tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ const int * Axi = (const int *) A.x;
+ const int * Bxi = (const int *) B.x;
+ int * Dxi = (int *) D.x;
+ asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+ "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
+ asm("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 "
+ "{%0, %1, %2, %3, %4, %5, %6, %7}, {%8, %9}, {%10, %11}, {%0, %1, %2, %3, %4, %5, %6, %7};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ }
+
+ static __device__ __forceinline__ void mma(
+ tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ const int * Axi = (const int *) A.x;
+ const int * Bxi = (const int *) B.x;
+ int * Dxi = (int *) D.x;
+ asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
+ "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]), "r"(Bxi[1]));
+ asm("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16 "
+ "{%0, %1, %2, %3}, {%4, %5}, {%6, %7}, {%0, %1, %2, %3};"
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]), "r"(Bxi[3]));
+#else
+ GGML_UNUSED_VARS(D, A, B);
+ NO_DEVICE_CODE;
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+ }
+
+ template <data_layout dl_d, data_layout dl_ab>
+ static __device__ __forceinline__ void mma(
+ tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) {
+#if defined(AMD_WMMA_AVAILABLE)
+ using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
+ int32x8_t * acc = (int32x8_t *) D.x;
+#if defined(RDNA4)
+ using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+ int32x2_t * a_vec = (int32x2_t *) A.x;
+ int32x2_t * b_vec = (int32x2_t *) B.x;
+
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
+ true,
+ a_vec[0],
+ true,
+ b_vec[0],
+ acc[0],
+ false
+ );
+#elif defined(RDNA3)
+ using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+ int32x4_t * a_vec = (int32x4_t *) A.x;
+ int32x4_t * b_vec = (int32x4_t *) B.x;
+
+ acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
+ true,
+ a_vec[0],
+ true,
+ b_vec[0],
+ acc[0],
+ false
+ );
+#endif // defined(RDNA4)
+#else
+ GGML_UNUSED(D);
+ GGML_UNUSED(A);
+ GGML_UNUSED(B);
+ NO_DEVICE_CODE;
+#endif // AMD_WMMA_AVAILABLE
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/mmf.cu b/llama.cpp/ggml/src/ggml-cuda/mmf.cu
new file mode 100644
index 0000000..aad4c34
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/mmf.cu
@@ -0,0 +1,191 @@
+#include "ggml.h"
+#include "mmf.cuh"
+#include "mmid.cuh"
+
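+// Number of src0 rows processed per CUDA block: 64 on CDNA, 32 on all other devices.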
+static __forceinline__ int mmf_get_rows_per_block(const int cc) {
+ if (GGML_CUDA_CC_IS_CDNA(cc)) {
+ return MMF_ROWS_PER_BLOCK_CDNA;
+ } else {
+ return MMF_ROWS_PER_BLOCK;
+ }
+}
+
+void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+ GGML_ASSERT( src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const size_t ts_src0 = ggml_type_size(src0->type);
+ const size_t ts_src1 = ggml_type_size(src1->type);
+ const size_t ts_dst = ggml_type_size(dst->type);
+
+ GGML_ASSERT(ne13 == ne3);
+
+ GGML_ASSERT( nb00 == ts_src0);
+ GGML_ASSERT( nb10 == ts_src1);
+ GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+ GGML_ASSERT( nb0 == ts_dst);
+
+ const float * src1_d = (const float *) src1->data;
+ const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
+ float * dst_d = (float *) dst->data;
+
+ const int64_t s01 = src0->nb[1] / ts_src0;
+ const int64_t s11 = src1->nb[1] / ts_src1;
+ const int64_t s1 = dst->nb[1] / ts_dst;
+ const int64_t s02 = src0->nb[2] / ts_src0;
+ const int64_t s12 = src1->nb[2] / ts_src1;
+ const int64_t s2 = dst->nb[2] / ts_dst;
+ const int64_t s03 = src0->nb[3] / ts_src0;
+ const int64_t s13 = src1->nb[3] / ts_src1;
+ const int64_t s3 = dst->nb[3] / ts_dst;
+
+ const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
+ const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
+
+ mmf_ids_data ids_info{};
+ mmf_ids_data * ids_info_ptr = nullptr;
+ ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
+ ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
+ ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;
+
+ // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
+ const int64_t ncols_dst = ids ? ne2 : ne1;
+ const int64_t nchannels_dst = ids ? ne1 : ne2;
+
+ const int64_t stride_col_dst = ids ? s2 : s1;
+ const int64_t stride_col_y = ids ? s12 : s11;
+ const int64_t stride_channel_dst = ids ? s1 : s2;
+
+ int64_t stride_channel_y = ids ? s11 : s12;
+ int64_t nchannels_y = ids ? ne11 : ne12;
+
+ // mul_mat_id: handle broadcasting when src1 has a single channel
+ if (ids && nchannels_y == 1) {
+ stride_channel_y = 0;
+ nchannels_y = ids->ne[0];
+ }
+
+ if (ids && ncols_dst > 16) {
+ const int64_t n_expert_used = ids->ne[0];
+ const int64_t n_experts = ne02;
+ const int64_t n_tokens = ne12;
+ const int64_t ne_get_rows = n_tokens * n_expert_used;
+
+ ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
+ ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
+ expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);
+
+ const int si1 = static_cast<int>(ids_s1);
+ const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);
+
+ GGML_ASSERT(sis1 > 0);
+
+ ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
+ static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
+ CUDA_CHECK(cudaGetLastError());
+
+ ids_info.ids_src_compact = ids_src_compact_dev.get();
+ ids_info.ids_dst_compact = ids_dst_compact_dev.get();
+ ids_info.expert_bounds_dev = expert_bounds_dev.get();
+ ids_info.n_experts = static_cast<int>(n_experts);
+ ids_info.sis1 = sis1;
+ ids_info_ptr = &ids_info;
+ }
+
+ const int device = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[device].cc;
+ const int rows_per_block = mmf_get_rows_per_block(cc);
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: {
+ const float * src0_d = (const float *) src0->data;
+ constexpr int vals_per_T = 1;
+ mul_mat_f_switch_rows_per_block<float>(
+ rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+ ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
+ } break;
+ case GGML_TYPE_F16: {
+ const half2 * src0_d = (const half2 *) src0->data;
+ constexpr int vals_per_T = 2;
+ mul_mat_f_switch_rows_per_block<half2>(
+ rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+ ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
+ } break;
+ case GGML_TYPE_BF16: {
+ const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
+ constexpr int vals_per_T = 2;
+ mul_mat_f_switch_rows_per_block<nv_bfloat162>(
+ rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+ ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
+ } break;
+ default:
+ GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+ }
+}
+
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne,
+ const size_t * src0_nb, const int src1_ncols, bool mul_mat_id) {
+ if (ggml_is_quantized(type)) {
+ return false;
+ }
+
+ const size_t ts = ggml_type_size(type);
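+ // The number of columns must be a multiple of one warp-wide load of 4-byte units
+ // (warp_size floats, or warp_size packed half2/nv_bfloat162 values).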
+ if (src0_ne[0] % (warp_size * (4/ts)) != 0) {
+ return false;
+ }
+
+ if (src0_nb[0] != ts) {
+ return false;
+ }
+
+ // Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
+ for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+ if (src0_nb[i] % (2*ts) != 0) {
+ return false;
+ }
+ }
+ if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) {
+ return false;
+ }
+
+ if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) {
+ return false;
+ }
+
+ if (mul_mat_id) {
+ if (src0_ne[1] <= 1024 && src1_ncols > 512) {
+ return false;
+ } else if (src0_ne[1] > 1024 && src1_ncols > 128) {
+ return false;
+ }
+ } else {
+ if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) {
+ return false;
+ } else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
+ // TODO: treat CDNA2 as CDNA1 for now; tune the performance when CDNA2 hardware is available.
+ return false;
+ } else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
+ return false;
+ } else if (src1_ncols > 16) {
+ return false;
+ }
+ }
+
+ switch (type) {
+ case GGML_TYPE_F32:
+ return ampere_mma_available(cc) || amd_mfma_available(cc);
+ case GGML_TYPE_F16:
+ return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
+ case GGML_TYPE_BF16:
+ return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
+ default:
+ return false;
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/mmf.cuh b/llama.cpp/ggml/src/ggml-cuda/mmf.cuh
new file mode 100644
index 0000000..c2a8d54
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/mmf.cuh
@@ -0,0 +1,908 @@
+#pragma once
+
+#include "mma.cuh"
+#include "common.cuh"
+#include "convert.cuh"
+
+using namespace ggml_cuda_mma;
+
+#define MMF_ROWS_PER_BLOCK 32
+#define MMF_ROWS_PER_BLOCK_CDNA 64
+
+static __forceinline__ int64_t mmf_get_max_block_size(int cc) {
+ if (GGML_CUDA_CC_IS_CDNA(cc)) {
+ return 512;
+ } else {
+ return 256;
+ }
+}
+
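+// Padding (in elements) appended to each shared-memory tile row, presumably to avoid bank conflicts.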
+static __forceinline__ int mmf_get_padding(int cc) {
+ if (GGML_CUDA_CC_IS_CDNA(cc)) {
+ return 2;
+ } else {
+ return 4;
+ }
+}
+
+static constexpr __device__ int mmf_get_padding() {
+#if defined(AMD_MFMA_AVAILABLE)
+ return 2;
+#else
+ return 4;
+#endif // defined(AMD_MFMA_AVAILABLE)
+}
+
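+// Compacted MUL_MAT_ID routing data produced by ggml_cuda_launch_mm_ids_helper: per-expert
+// contiguous lists of source/destination indices plus the start offset of each expert's range.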
+struct mmf_ids_data {
+ const int32_t * ids_src_compact = nullptr;
+ const int32_t * ids_dst_compact = nullptr;
+ const int32_t * expert_bounds_dev = nullptr;
+ int n_experts = 0;
+ int sis1 = 0;
+};
+
+void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const size_t * src0_nb, const int src1_ncols, bool mul_mat_id);
+
+template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f(
+ const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+ const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
+ const int stride_col_id, const int stride_row_id,
+ const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+ const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE)
+ if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+ typedef tile<16, 8, T, get_input_data_layout()> tile_A;
+ typedef tile<16, 8, T, get_input_data_layout()> tile_B;
+ typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
+#elif defined(AMD_MFMA_AVAILABLE)
+ if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
+ typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
+ typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
+ typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
+#else
+#ifdef VOLTA_MMA_AVAILABLE
+ if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+ typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
+ typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
+ typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
+#else
+ if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+ typedef tile<16, 8, T> tile_A;
+ typedef tile<8, 8, T> tile_B;
+ typedef tile<16, 8, float> tile_C;
+#endif // VOLTA_MMA_AVAILABLE
+#endif // defined(AMD_WMMA_AVAILABLE)
+ if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+ constexpr int tile_k_padded = warp_size + mmf_get_padding();
+ constexpr int ntA = rows_per_block / tile_A::I;
+ constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
+
+ const int row0 = blockIdx.x * rows_per_block;
+
+ int expert_idx = 0;
+ int col_base = 0;
+
+ const int channel_dst = has_ids ? 0 : blockIdx.y;
+
+ if constexpr (has_ids) {
+ // experts + tiles of ncols_dst are packed in the y dimension
+ int col_tiles = (ncols_dst_total + cols_per_block - 1) / cols_per_block;
+ const int nchannels_x = gridDim.y / col_tiles;
+ const int tile_idx = blockIdx.y / nchannels_x;
+ expert_idx = blockIdx.y - tile_idx * nchannels_x;
+ col_base = tile_idx * cols_per_block;
+ }
+
+ const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
+ const int channel_y = channel_dst;
+ const int sample_dst = blockIdx.z;
+ const int sample_x = sample_dst / sample_ratio;
+ const int sample_y = sample_dst;
+
+ x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
+ y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y);
+ dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
+
+ if constexpr (has_ids) {
+ constexpr int y_stride_scale = std::is_same_v<T, float> ? 1 : 2;
+ const int64_t col_offset = col_base;
+ y += col_offset * stride_col_y * y_stride_scale;
+ dst += col_offset * stride_col_dst;
+ ids += col_offset * stride_row_id;
+ }
+
+ const float2 * y2 = (const float2 *) y;
+
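+ // Dynamic shared memory holds an optional per-column slot_map (ids path) followed by a buffer
+ // that serves first as the tile staging area (tile_xy) and later for the cross-warp reduction (buf_iw).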
+ extern __shared__ char data_mmv[];
+
+ char * shmem_base = data_mmv;
+ int * slot_map = (int *) shmem_base;
+ char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base;
+
+ tile_C C[ntA][ntB];
+
+ T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
+
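+ // For each output column j of this tile, find the channel whose id matches expert_idx and
+ // record it in slot_map[j] (-1 if none); bail out early if no column routes to this expert.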
+ if constexpr (has_ids) {
+ int found = 0;
+
+ for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (threadIdx.x == 0) {
+ slot_map[j] = -1;
+ }
+
+ if (col_base + j >= ncols_dst_total) {
+ continue;
+ }
+
+ const int32_t * __restrict__ id_row = ids + j*stride_row_id;
+
+ for (int k = threadIdx.x; k < nchannels_dst; k += warp_size) {
+ int match = id_row[k*stride_col_id] == expert_idx;
+
+ if (match) {
+ slot_map[j] = k;
+ found = 1;
+ break;
+ }
+ }
+ }
+
+ if (!__syncthreads_or(found)) {
+ return;
+ }
+ }
+
+
+ for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
+ tile_A A[ntA][warp_size / tile_A::J];
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+ for (int i = 0; i < tile_A::I; ++i) {
+ tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
+ }
+#pragma unroll
+ for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
+ load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
+ }
+ }
+
+#pragma unroll
+ for (int itB = 0; itB < ntB; ++itB) {
+ if constexpr (std::is_same_v<T, float>) {
+#pragma unroll
+ for (int j0 = 0; j0 < tile_B::I; ++j0) {
+ const int j = j0 + itB*tile_B::I;
+
+ if constexpr (!has_ids) {
+ tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
+ } else {
+ const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
+ tile_xy[j0*tile_k_padded + threadIdx.x] = valid ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
+ }
+ }
+ } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+#pragma unroll
+ for (int j0 = 0; j0 < tile_B::I; ++j0) {
+ const int j = j0 + itB*tile_B::I;
+
+ if constexpr (!has_ids) {
+ const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
+ tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
+ } else {
+ const bool valid = j < cols_per_block && (col_base + j) < ncols_dst_total && slot_map[j] >= 0;
+ float2 tmp = valid ? *(const float2*) &y[slot_map[j]*stride_channel_y + 2*(j*stride_col_y + col)] : make_float2(0.0f, 0.0f);
+ tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
+ }
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "unsupported type");
+ }
+#pragma unroll
+ for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+ tile_B B;
+ load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+ mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+ }
+ }
+ }
+ }
+
+ float * buf_iw = (float *) compute_base;
+ constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
+
+ if (nwarps > 1) {
+ __syncthreads();
+ }
+#pragma unroll
+ for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
+ const int j = itB*tile_C::J + tile_C::get_j(l);
+ buf_iw[j*kiw + i] = C[itA][itB].x[l];
+ }
+ }
+ }
+
+ if (nwarps > 1) {
+ __syncthreads();
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
+ return;
+ }
+
+ float sum[rows_per_block/warp_size] = {0.0f};
+ static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
+#pragma unroll
+ for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
+#pragma unroll
+ for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
+ const int i = i0 + i1*warp_size + threadIdx.x;
+
+ sum[i1] += buf_iw[j*kiw + i];
+ }
+ }
+
+ if constexpr (!has_ids) {
+#pragma unroll
+ for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
+ dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
+ }
+ } else {
+ const int slot = (j < cols_per_block) ? slot_map[j] : -1;
+ if (slot >= 0 && (col_base + j) < ncols_dst_total) {
+#pragma unroll
+ for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
+ dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
+ }
+ }
+ }
+ }
+ }
+#else
+ GGML_UNUSED_VARS(x, y, ids, dst,
+ ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ NO_DEVICE_CODE;
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+}
+
+// This kernel is for larger batch sizes of mul_mat_id.
+template <typename T, int rows_per_block, int cols_per_block, int nwarps>
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f_ids(
+ const T * __restrict__ x, const float * __restrict__ y,
+ const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
+ const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
+ const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
+ const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+ const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+ const uint3 sis1_fd, const uint3 nch_fd) {
+// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE)
+ if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+ typedef tile<16, 8, T, get_input_data_layout()> tile_A;
+ typedef tile<16, 8, T, get_input_data_layout()> tile_B;
+ typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
+#elif defined(AMD_MFMA_AVAILABLE)
+ if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
+ typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
+ typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
+ typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
+#else
+#ifdef VOLTA_MMA_AVAILABLE
+ if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+ typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
+ typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
+ typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
+#else
+ if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+ typedef tile<16, 8, T> tile_A;
+ typedef tile<8, 8, T> tile_B;
+ typedef tile<16, 8, float> tile_C;
+#endif // VOLTA_MMA_AVAILABLE
+#endif // defined(AMD_WMMA_AVAILABLE)
+ if constexpr (!tile_A::supported() || !tile_B::supported() || !tile_C::supported()) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
+
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+ constexpr int tile_k_padded = warp_size + mmf_get_padding();
+ constexpr int ntA = rows_per_block / tile_A::I;
+ constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
+
+ const int row0 = blockIdx.x * rows_per_block;
+
+ const int expert_idx = blockIdx.y;
+ const int expert_start = expert_bounds[expert_idx];
+ const int expert_end = expert_bounds[expert_idx + 1];
+ const int ncols_expert = expert_end - expert_start;
+
+ const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;
+ const int tile_idx = blockIdx.z;
+ if (tile_idx >= tiles_for_expert) {
+ return;
+ }
+
+ const int col_base = tile_idx * cols_per_block;
+
+ GGML_UNUSED(channel_ratio);
+
+ const int channel_x = expert_idx;
+ const int sample_dst = 0;
+ const int sample_x = sample_dst / sample_ratio;
+ const int sample_y = sample_dst;
+
+ x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row;
+ y += int64_t(sample_y) *stride_sample_y;
+ dst += int64_t(sample_dst)*stride_sample_dst;
+
+ const int32_t * ids_src_expert = ids_src_compact + expert_start;
+ const int32_t * ids_dst_expert = ids_dst_compact + expert_start;
+
+ extern __shared__ char data_mmv[];
+ char * compute_base = data_mmv;
+
+ //const float2 * y2 = (const float2 *) y;
+
+ tile_C C[ntA][ntB];
+
+ T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
+
+ for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
+ tile_A A[ntA][warp_size / tile_A::J];
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+ for (int i = 0; i < tile_A::I; ++i) {
+ tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
+ }
+#pragma unroll
+ for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
+ load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
+ }
+ }
+
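+ // Gather the y values for the B tiles into registers, double-buffering so that the gather
+ // of the next tile overlaps with the MMA work of the current one.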
+ if constexpr (std::is_same_v<T, float>) {
+ float vals_buf[2][tile_B::I];
+ auto gather_tile = [&](int tile_idx_local, float *vals) {
+#pragma unroll
+ for (int j0 = 0; j0 < tile_B::I; ++j0) {
+ const int j = j0 + tile_idx_local*tile_B::I;
+ const int global_j = col_base + j;
+ float val = 0.0f;
+ if (j < cols_per_block && global_j < ncols_expert) {
+ const int src_entry = ids_src_expert[global_j];
+ const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
+ const int token = (int) qrm.x;
+ const int channel = (int) qrm.y;
+ if (token < ncols_dst_total) {
+ val = y[channel*stride_channel_y + token*stride_col_y + col];
+ }
+ }
+ vals[j0] = val;
+ }
+ };
+
+ gather_tile(0, vals_buf[0]);
+
+ int curr_buf = 0;
+ int next_buf = 1;
+#pragma unroll
+ for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+ for (int j0 = 0; j0 < tile_B::I; ++j0) {
+ tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
+ }
+
+ if (itB + 1 < ntB) {
+ gather_tile(itB + 1, vals_buf[next_buf]);
+ }
+
+#pragma unroll
+ for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+ tile_B B;
+ load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+ mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+ }
+ }
+
+ if (itB + 1 < ntB) {
+ curr_buf ^= 1;
+ next_buf ^= 1;
+ }
+ }
+ } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+ float2 vals_buf[2][tile_B::I];
+ auto gather_tile = [&](int tile_idx_local, float2 *vals) {
+#pragma unroll
+ for (int j0 = 0; j0 < tile_B::I; ++j0) {
+ const int j = j0 + tile_idx_local*tile_B::I;
+ const int global_j = col_base + j;
+ float2 tmp = make_float2(0.0f, 0.0f);
+ if (j < cols_per_block && global_j < ncols_expert) {
+ const int src_entry = ids_src_expert[global_j];
+ const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
+ const int token = (int) qrm.x;
+ const int channel = (int) qrm.y;
+ if (token < ncols_dst_total) {
+ tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
+ }
+ }
+ vals[j0] = tmp;
+ }
+ };
+
+ if (ntB > 0) {
+ gather_tile(0, vals_buf[0]);
+ }
+
+ int curr_buf = 0;
+ int next_buf = 1;
+#pragma unroll
+ for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+ for (int j0 = 0; j0 < tile_B::I; ++j0) {
+ const float2 tmp = vals_buf[curr_buf][j0];
+ tile_xy[j0*tile_k_padded + threadIdx.x] = ggml_cuda_cast<T>(tmp);
+ }
+
+ if (itB + 1 < ntB) {
+ gather_tile(itB + 1, vals_buf[next_buf]);
+ }
+
+#pragma unroll
+ for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+ tile_B B;
+ load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+ mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+ }
+ }
+
+ if (itB + 1 < ntB) {
+ curr_buf ^= 1;
+ next_buf ^= 1;
+ }
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "unsupported type");
+ }
+ }
+
+ float * buf_iw = (float *) compute_base;
+ constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
+
+ if (nwarps > 1) {
+ __syncthreads();
+ }
+#pragma unroll
+ for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
+ const int j = itB*tile_C::J + tile_C::get_j(l);
+ buf_iw[j*kiw + i] = C[itA][itB].x[l];
+ }
+ }
+ }
+
+ if (nwarps > 1) {
+ __syncthreads();
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
+ return;
+ }
+
+ float sum[rows_per_block/warp_size] = {0.0f};
+ static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
+#pragma unroll
+ for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
+#pragma unroll
+ for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
+ const int i = i0 + i1*warp_size + threadIdx.x;
+
+ sum[i1] += buf_iw[j * kiw + i];
+ }
+ }
+
+ const int global_j = col_base + j;
+ if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {
+ const int dst_entry = ids_dst_expert[global_j];
+ const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd);
+ const int token = (int) qrm.x;
+ if (token < ncols_dst_total) {
+ const int slot = (int) qrm.y;
+#pragma unroll
+ for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
+ dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
+ }
+ }
+ }
+ }
+ }
+#else
+ GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
+ ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
+ NO_DEVICE_CODE;
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+}
+
+template<typename T, int rows_per_block, int cols_per_block, int nwarps>
+static inline void mul_mat_f_switch_ids(
+ const T * x, const float * y, const int32_t * ids, float * dst,
+ const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const int64_t stride_col_id, const int64_t stride_row_id,
+ const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+ const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,
+ const mmf_ids_data * ids_data) {
+ const bool has_ids_data = ids_data && ids_data->ids_src_compact;
+
+    // Use the compact-ids kernel only when there are enough destination columns; for small ncols_dst (<= 16)
+    // the normal mul_mat_f path with has_ids=true is preferred.
+ if (has_ids_data && ncols_dst > 16) {
+ const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);
+ if (max_tiles == 0) {
+ return;
+ }
+ dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);
+
+ const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
+ const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
+
+ mul_mat_f_ids<T, rows_per_block, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
+ (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
+ ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
+ sis1_fd, nch_fd);
+ } else if (ids) {
+ const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
+ dim3 block_nums_ids = block_nums;
+ block_nums_ids.y *= col_tiles;
+
+ mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
+ (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ } else {
+ mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+ (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ }
+}
+
+template <typename T, int rows_per_block, int cols_per_block>
+void mul_mat_f_cuda(
+ const T * x, const float * y, const int32_t * ids, float * dst,
+ const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const int64_t stride_col_id, const int64_t stride_row_id,
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ cudaStream_t stream, const mmf_ids_data * ids_data) {
+ typedef tile<16, 8, T> tile_A_16;
+ typedef tile<32, 8, T> tile_A_32;
+ typedef tile<16, 8, T> tile_B_16;
+ typedef tile< 8, 8, T> tile_B_8;
+
+ GGML_ASSERT(ncols_x % 2 == 0);
+ GGML_ASSERT(stride_row % 2 == 0);
+ GGML_ASSERT(stride_col_y % 2 == 0);
+ GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+ GGML_ASSERT( nsamples_dst % nsamples_x == 0);
+ const int64_t channel_ratio = nchannels_dst / nchannels_x;
+ const int64_t sample_ratio = nsamples_dst / nsamples_x;
+
+ const int device = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[device].cc;
+ const int warp_size = ggml_cuda_info().devices[device].warp_size;
+
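+    // Heuristic: pick the number of warps that minimizes the number of K-dimension iterations,
+    // i.e. niter = ceil(ncols_x / (nwarps*warp_size*2)).
+    // For example (illustrative values only): ncols_x=4096 and warp_size=32 give niter=64 for
+    // nwarps=1 but niter=8 for nwarps=8, so nwarps_best=8 if the block size limit allows it.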
+ int64_t nwarps_best = 1;
+ int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
+ int64_t max_block_size = mmf_get_max_block_size(cc);
+ for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
+ const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
+ if (niter < niter_best) {
+ niter_best = niter;
+ nwarps_best = nwarps;
+ }
+ }
+
+ const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4;
+ const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I;
+ const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4;
+ const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
+ const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
+ const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
+ const int64_t grid_y = ids ? nchannels_x : nchannels_dst;
+
+ const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
+ const dim3 block_dims(warp_size, nwarps_best, 1);
+
+ switch (nwarps_best) {
+ case 1: {
+ mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 1>(
+ x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+ ids_data);
+ } break;
+ case 2: {
+ mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 2>(
+ x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+ ids_data);
+ } break;
+ case 3: {
+ mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 3>(
+ x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+ ids_data);
+ } break;
+ case 4: {
+ mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 4>(
+ x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+ ids_data);
+ } break;
+ case 5: {
+ mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 5>(
+ x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+ ids_data);
+ } break;
+ case 6: {
+ mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 6>(
+ x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+ ids_data);
+ } break;
+ case 7: {
+ mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 7>(
+ x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+ ids_data);
+ } break;
+ case 8: {
+ mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 8>(
+ x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+ ids_data);
+ } break;
+ default: {
+ GGML_ABORT("fatal error");
+ } break;
+ }
+
+ GGML_UNUSED_VARS(nchannels_y);
+}
+
+template <typename T, int rows_per_block>
+static void mul_mat_f_switch_cols_per_block(
+ const T * x, const float * y, const int32_t * ids, float * dst,
+ const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const int64_t stride_col_id, const int stride_row_id,
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ cudaStream_t stream, const mmf_ids_data * ids_data) {
+
+ const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
+
+ GGML_ASSERT(ids || ncols_dst <= 16);
+
+ switch (ncols_case) {
+ case 1: {
+ mul_mat_f_cuda<T, rows_per_block, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case 2: {
+ mul_mat_f_cuda<T, rows_per_block, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case 3: {
+ mul_mat_f_cuda<T, rows_per_block, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case 4: {
+ mul_mat_f_cuda<T, rows_per_block, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case 5: {
+ mul_mat_f_cuda<T, rows_per_block, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case 6: {
+ mul_mat_f_cuda<T, rows_per_block, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case 7: {
+ mul_mat_f_cuda<T, rows_per_block, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case 8: {
+ mul_mat_f_cuda<T, rows_per_block, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case 9: {
+ mul_mat_f_cuda<T, rows_per_block, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case 10: {
+ mul_mat_f_cuda<T, rows_per_block, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case 11: {
+ mul_mat_f_cuda<T, rows_per_block, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case 12: {
+ mul_mat_f_cuda<T, rows_per_block, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case 13: {
+ mul_mat_f_cuda<T, rows_per_block, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case 14: {
+ mul_mat_f_cuda<T, rows_per_block, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case 15: {
+ mul_mat_f_cuda<T, rows_per_block, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case 16: {
+ mul_mat_f_cuda<T, rows_per_block, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ default: {
+ GGML_ABORT("fatal error");
+ } break;
+ }
+}
+
+template <typename T>
+static void mul_mat_f_switch_rows_per_block(
+ const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst,
+ const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const int64_t stride_col_id, const int stride_row_id,
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ cudaStream_t stream, const mmf_ids_data * ids_data) {
+ switch (rows_per_block) {
+ case MMF_ROWS_PER_BLOCK: {
+ mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK>(
+ x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case MMF_ROWS_PER_BLOCK_CDNA: {
+ mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK_CDNA>(
+ x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ default:
+ GGML_ABORT("unsupported rows_per_block: %i", rows_per_block);
+ }
+}
+
+#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
+ template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \
+ const T * x, const float * y, const int32_t * ids, float * dst, \
+ const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
+ const int64_t stride_col_id, const int64_t stride_row_id, \
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
+ cudaStream_t stream, const mmf_ids_data * ids_data);
+
+#if !defined(GGML_USE_MUSA)
+#define DECL_MMF_CASE_EXTERN(ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
+
+#define DECL_MMF_CASE(ncols_dst) \
+ DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
+ DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
+ DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
+ DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
+ DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
+ DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
+
+DECL_MMF_CASE_EXTERN(1);
+DECL_MMF_CASE_EXTERN(2);
+DECL_MMF_CASE_EXTERN(3);
+DECL_MMF_CASE_EXTERN(4);
+DECL_MMF_CASE_EXTERN(5);
+DECL_MMF_CASE_EXTERN(6);
+DECL_MMF_CASE_EXTERN(7);
+DECL_MMF_CASE_EXTERN(8);
+DECL_MMF_CASE_EXTERN(9);
+DECL_MMF_CASE_EXTERN(10);
+DECL_MMF_CASE_EXTERN(11);
+DECL_MMF_CASE_EXTERN(12);
+DECL_MMF_CASE_EXTERN(13);
+DECL_MMF_CASE_EXTERN(14);
+DECL_MMF_CASE_EXTERN(15);
+DECL_MMF_CASE_EXTERN(16);
+#else
+#define DECL_MMF_CASE(ncols_dst)
+#endif // !defined(GGML_USE_MUSA)
diff --git a/llama.cpp/ggml/src/ggml-cuda/mmid.cu b/llama.cpp/ggml/src/ggml-cuda/mmid.cu
new file mode 100644
index 0000000..3c61e45
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/mmid.cu
@@ -0,0 +1,164 @@
+#include "common.cuh"
+#include "mmid.cuh"
+
+// To reduce shared memory use, pack "it" and "iex_used" into a single uint32_t with 22 and 10 bits, respectively.
+struct mm_ids_helper_store {
+ uint32_t data;
+
+ __device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
+ data = (it & 0x003FFFFF) | (iex_used << 22);
+ }
+
+ __device__ uint32_t it() const {
+ return data & 0x003FFFFF;
+ }
+
+ __device__ uint32_t iex_used() const {
+ return data >> 22;
+ }
+};
+static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store");
+
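The packing above is plain masking and shifting; a minimal host-side sketch of the same round trip (made-up values, for illustration only — the asserts in launch_mm_ids_helper below guarantee the ranges):

    #include <cassert>
    #include <cstdint>

    int main() {
        const uint32_t it       = 123456; // token index, must be < (1 << 22)
        const uint32_t iex_used = 5;      // expert slot,  must be < (1 << 10)
        const uint32_t data     = (it & 0x003FFFFF) | (iex_used << 22);
        assert((data & 0x003FFFFF) == it);       // same as mm_ids_helper_store::it()
        assert((data >> 22)        == iex_used); // same as mm_ids_helper_store::iex_used()
        return 0;
    }
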
+// Helper function for mul_mat_id, converts ids to a more convenient format.
+// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
+// ids_dst describes the same mapping but for the dst tensor.
+// The lower and upper bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i] and expert_bounds[i+1], respectively.
+template <int n_expert_used_template>
+__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
+static __global__ void mm_ids_helper(
+ const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+ const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+ const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
+ const int expert = blockIdx.x;
+
+ extern __shared__ char data_mm_ids_helper[];
+ mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper;
+
+ int nex_prev = 0; // Number of columns for experts with a lower index.
+ int it_compact = 0; // Running index for the compact slice of this expert.
+
+ if constexpr (n_expert_used_template == 0) {
+ // Generic implementation:
+ for (int it = 0; it < n_tokens; ++it) {
+ int iex_used = -1; // The index at which the expert is used, if any.
+ for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
+ const int expert_used = ids[it*si1 + iex];
+ nex_prev += expert_used < expert;
+ if (expert_used == expert) {
+ iex_used = iex;
+ }
+ }
+
+ if (iex_used != -1) {
+ store[it_compact] = mm_ids_helper_store(it, iex_used);
+ }
+
+ if (warp_reduce_any<warp_size>(iex_used != -1)) {
+ it_compact++;
+ }
+ }
+ } else {
+ // Implementation optimized for specific numbers of experts used:
+        static_assert(n_expert_used_template == 6 || warp_size % n_expert_used_template == 0, "bad n_expert_used_template");
+ const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
+ for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
+ const int it = it0 + threadIdx.x / neu_padded;
+
+            const int iex = threadIdx.x % neu_padded; // The expert slot of this token that this thread checks.
+ const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
+ ids[it*si1 + iex] : INT_MAX;
+ const int iex_used = expert_used == expert ? iex : -1;
+ nex_prev += expert_used < expert;
+
+ // Whether the threads at this token position have used the expert:
+ const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
+
+ // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
+ int it_compact_add_lower = 0;
+#pragma unroll
+ for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
+ const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
+ if (threadIdx.x >= static_cast<unsigned int>(offset)) {
+ it_compact_add_lower += tmp;
+ }
+ }
+
+ if (iex_used != -1) {
+ store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used);
+ }
+
+            // The thread with the highest index in the warp always has the sum over the whole warp; use it to increment it_compact for all threads:
+ it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
+ }
+ }
+ nex_prev = warp_reduce_sum<warp_size>(nex_prev);
+
+ for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
+ const mm_ids_helper_store store_it = store[itc];
+ const int it = store_it.it();
+ const int iex_used = store_it.iex_used();
+ ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
+ ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
+ }
+
+ if (threadIdx.x != 0) {
+ return;
+ }
+
+ expert_bounds[expert] = nex_prev;
+
+ if (expert < static_cast<int>(gridDim.x) - 1) {
+ return;
+ }
+
+ expert_bounds[gridDim.x] = nex_prev + it_compact;
+}
+
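For reference, a minimal CPU sketch of the mapping this kernel produces (the helper below is hypothetical, e.g. for unit tests; it assumes contiguous ids rows, i.e. si1 == n_expert_used, and that each token lists each expert at most once):

    #include <cstdint>
    #include <vector>

    // ids is [n_tokens][n_expert_used], row-major.
    static void mm_ids_helper_ref(
            const std::vector<int32_t> & ids, int n_experts, int n_tokens, int n_expert_used,
            int nchannels_y, int sis1,
            std::vector<int32_t> & ids_src1, std::vector<int32_t> & ids_dst,
            std::vector<int32_t> & expert_bounds) {
        ids_src1.clear();
        ids_dst.clear();
        expert_bounds.assign(n_experts + 1, 0);
        for (int expert = 0; expert < n_experts; ++expert) {
            expert_bounds[expert] = (int32_t) ids_src1.size(); // columns of all lower experts
            for (int it = 0; it < n_tokens; ++it) {
                for (int iex = 0; iex < n_expert_used; ++iex) {
                    if (ids[it*n_expert_used + iex] == expert) {
                        ids_src1.push_back(it*sis1 + iex % nchannels_y);
                        ids_dst .push_back(it*n_expert_used + iex);
                    }
                }
            }
        }
        expert_bounds[n_experts] = (int32_t) ids_src1.size();
    }
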
+template <int n_expert_used_template>
+static void launch_mm_ids_helper(
+ const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+ const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
+ GGML_ASSERT(n_tokens < (1 << 22) && "too few bits in mm_ids_helper_store");
+ GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store");
+
+ const int id = ggml_cuda_get_device();
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+ CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper<n_expert_used_template>, smpbo);
+
+ const dim3 num_blocks(n_experts, 1, 1);
+ const dim3 block_size(warp_size, 1, 1);
+ const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store);
+ GGML_ASSERT(nbytes_shared <= smpbo);
+ mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
+ (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
+}
+
+void ggml_cuda_launch_mm_ids_helper(
+ const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+ const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
+ switch (n_expert_used) {
+ case 2:
+ launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+ break;
+ case 4:
+ launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+ break;
+ case 6:
+ launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+ break;
+ case 8:
+ launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+ break;
+ case 16:
+ launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+ break;
+ case 32:
+ launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+ break;
+ default:
+ launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+ break;
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/mmid.cuh b/llama.cpp/ggml/src/ggml-cuda/mmid.cuh
new file mode 100644
index 0000000..ac090ae
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/mmid.cuh
@@ -0,0 +1,5 @@
+#pragma once
+
+void ggml_cuda_launch_mm_ids_helper(
+ const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
+ int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);
diff --git a/llama.cpp/ggml/src/ggml-cuda/mmq.cu b/llama.cpp/ggml/src/ggml-cuda/mmq.cu
new file mode 100644
index 0000000..9a69f41
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/mmq.cu
@@ -0,0 +1,366 @@
+#include "common.cuh"
+#include "mmq.cuh"
+#include "quantize.cuh"
+#include "mmid.cuh"
+
+static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
+ switch (args.type_x) {
+ case GGML_TYPE_Q4_0:
+ mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q4_1:
+ mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q5_0:
+ mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q5_1:
+ mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q8_0:
+ mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
+ break;
+ case GGML_TYPE_MXFP4:
+ mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q2_K:
+ mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q3_K:
+ mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q4_K:
+ mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q5_K:
+ mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
+ break;
+ case GGML_TYPE_Q6_K:
+ mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ2_XXS:
+ mul_mat_q_case<GGML_TYPE_IQ2_XXS>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ2_XS:
+ mul_mat_q_case<GGML_TYPE_IQ2_XS>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ2_S:
+ mul_mat_q_case<GGML_TYPE_IQ2_S>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ3_XXS:
+ mul_mat_q_case<GGML_TYPE_IQ3_XXS>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ3_S:
+ mul_mat_q_case<GGML_TYPE_IQ3_S>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ1_S:
+ mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ4_XS:
+ mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
+ break;
+ case GGML_TYPE_IQ4_NL:
+ mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
+ break;
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+}
+
+void ggml_cuda_mul_mat_q(
+ ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+ GGML_ASSERT( src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ cudaStream_t stream = ctx.stream();
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+ const size_t ts_src0 = ggml_type_size(src0->type);
+ const size_t ts_src1 = ggml_type_size(src1->type);
+ const size_t ts_dst = ggml_type_size(dst->type);
+
+ GGML_ASSERT( nb00 == ts_src0);
+ GGML_ASSERT( nb10 == ts_src1);
+ GGML_ASSERT( nb0 == ts_dst);
+ GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+
+ const char * src0_d = (const char *) src0->data;
+ const float * src1_d = (const float *) src1->data;
+ float * dst_d = (float *) dst->data;
+
+ // If src0 is a temporary compute buffer, clear any potential padding.
+ if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
+ const size_t size_data = ggml_nbytes(src0);
+ const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
+ if (size_alloc > size_data) {
+ GGML_ASSERT(ggml_is_contiguously_allocated(src0));
+ GGML_ASSERT(!src0->view_src);
+ CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
+ }
+ }
+
+ const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+
+ const int64_t s01 = src0->nb[1] / ts_src0;
+ const int64_t s1 = dst->nb[1] / ts_dst;
+ const int64_t s02 = src0->nb[2] / ts_src0;
+ const int64_t s2 = dst->nb[2] / ts_dst;
+ const int64_t s03 = src0->nb[3] / ts_src0;
+ const int64_t s3 = dst->nb[3] / ts_dst;
+
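+    // The stream-k decomposition is only expected to be faster on recent NVIDIA GPUs (Volta+) or CDNA;
+    // see the fuller comment in ggml_cuda_op_mul_mat_q below.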
+ const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
+ || GGML_CUDA_CC_IS_CDNA(cc);
+
+ // TODO: tighter pool buffer size vs q8 path
+ const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
+
+ if (!ids) {
+ const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
+ get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
+ ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
+
+ {
+ const int64_t s11 = src1->nb[1] / ts_src1;
+ const int64_t s12 = src1->nb[2] / ts_src1;
+ const int64_t s13 = src1->nb[3] / ts_src1;
+ if (use_native_mxfp4) {
+ static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
+ quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
+ ne11, ne12, ne13, stream);
+
+ } else {
+ quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
+ ne11, ne12, ne13, stream);
+ }
+ CUDA_CHECK(cudaGetLastError());
+ }
+
+ // Stride depends on quantization format
+ const int64_t s12 = use_native_mxfp4 ?
+ ne11 * ne10_padded * sizeof(block_fp4_mmq) /
+ (8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32)
+ :
+ ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
+ const int64_t s13 = ne12*s12;
+
+ const mmq_args args = {
+ src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
+ ne00, ne01, ne1, s01, ne11, s1,
+ ne02, ne12, s02, s12, s2,
+ ne03, ne13, s03, s13, s3,
+ use_stream_k, ne1};
+ ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
+ return;
+ }
+
+ GGML_ASSERT(ne13 == 1);
+ GGML_ASSERT(nb12 % nb11 == 0);
+ GGML_ASSERT(nb2 % nb1 == 0);
+
+ const int64_t n_expert_used = ids->ne[0];
+ const int64_t ne_get_rows = ne12 * n_expert_used;
+ GGML_ASSERT(ne1 == n_expert_used);
+
+ ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
+ ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
+ ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
+
+ {
+ GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
+ const int si1 = ids->nb[1] / ggml_element_size(ids);
+ const int sis1 = nb12 / nb11;
+
+ ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+ ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+ CUDA_CHECK(cudaGetLastError());
+ }
+
+ const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
+ get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
+ ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), nbytes_src1_q8_1);
+
+ const int64_t ne11_flat = ne12*n_expert_used;
+ const int64_t ne12_flat = 1;
+ const int64_t ne13_flat = 1;
+
+ {
+ const int64_t s11 = src1->nb[1] / ts_src1;
+ const int64_t s12 = src1->nb[2] / ts_src1;
+ const int64_t s13 = src1->nb[3] / ts_src1;
+
+ if (use_native_mxfp4) {
+ quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
+ ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+ } else {
+ quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
+ ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+ }
+ CUDA_CHECK(cudaGetLastError());
+ }
+
+ const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
+ ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
+ const int64_t s13 = ne12*s12;
+
+ // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
+ const mmq_args args = {
+ src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
+ ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
+ ne02, ne02, s02, s12, s2,
+ ne03, ne13, s03, s13, s3,
+ use_stream_k, ne12};
+
+ ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
+}
+
+void ggml_cuda_op_mul_mat_q(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+ const int64_t ne00 = src0->ne[0];
+
+ const int64_t ne10 = src1->ne[0];
+ const int64_t ne11 = src1->ne[1];
+ GGML_ASSERT(ne10 % QK8_1 == 0);
+
+ const int64_t ne0 = dst->ne[0];
+
+ const int64_t row_diff = row_high - row_low;
+ const int64_t stride01 = ne00 / ggml_blck_size(src0->type);
+
+ const int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
+
+ // the main device has a larger memory buffer to hold the results from all GPUs
+ // nrows_dst == nrows of the matrix that the kernel writes into
+ const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
+
+ // The stream-k decomposition is only faster for recent NVIDIA GPUs.
+ // Also its fixup needs to allocate a temporary buffer in the memory pool.
+ // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
+ const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
+ || GGML_CUDA_CC_IS_CDNA(cc))
+ && src1_ncols == ne11;
+ const mmq_args args = {
+ src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
+ ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
+ 1, 1, 0, 0, 0,
+ 1, 1, 0, 0, 0,
+ use_stream_k, src1_ncols};
+
+ ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
+
+ GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_padded_row_size);
+}
+
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts) {
+#ifdef GGML_CUDA_FORCE_CUBLAS
+ return false;
+#endif // GGML_CUDA_FORCE_CUBLAS
+
+ bool mmq_supported;
+
+ switch (type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ3_S:
+ case GGML_TYPE_IQ1_S:
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ4_NL:
+ mmq_supported = true;
+ break;
+ default:
+ mmq_supported = false;
+ break;
+ }
+
+ if (!mmq_supported) {
+ return false;
+ }
+
+ if (turing_mma_available(cc)) {
+ return true;
+ }
+
+ if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) {
+ return false;
+ }
+
+#ifdef GGML_CUDA_FORCE_MMQ
+ return true;
+#endif //GGML_CUDA_FORCE_MMQ
+
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+ return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+ }
+
+ if (amd_mfma_available(cc)) {
+        // As of ROCm 7.0, rocblas/tensile performs very poorly on CDNA3; hipblaslt (via ROCBLAS_USE_HIPBLASLT)
+        // performs better but currently crashes on this architecture.
+ // TODO: Revisit when hipblaslt is fixed on CDNA3
+ if (GGML_CUDA_CC_IS_CDNA3(cc)) {
+ return true;
+ }
+ if (n_experts > 64 || ne11 <= 128) {
+ return true;
+ }
+ if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) {
+ return true;
+ }
+ if (ne11 <= 256 && (type == GGML_TYPE_Q4_K || type == GGML_TYPE_Q5_K)) {
+ return true;
+ }
+ return false;
+ }
+
+ if (amd_wmma_available(cc)) {
+ if (GGML_CUDA_CC_IS_RDNA3(cc)) {
+ // High expert counts are almost always better on MMQ due to
+ // the synchronization overhead in the cuBLAS/hipBLAS path:
+ // https://github.com/ggml-org/llama.cpp/pull/18202
+ if (n_experts >= 64) {
+ return true;
+ }
+
+ // For some quantization types MMQ can have lower peak TOPS than hipBLAS
+ // so it's only faster for sufficiently small batch sizes:
+ switch (type) {
+ case GGML_TYPE_Q2_K:
+ return ne11 <= 128;
+ case GGML_TYPE_Q6_K:
+ return ne11 <= (GGML_CUDA_CC_IS_RDNA3_0(cc) ? 128 : 256);
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_S:
+ return GGML_CUDA_CC_IS_RDNA3_5(cc) || ne11 <= 128;
+ default:
+ return true;
+ }
+ }
+
+ // For RDNA4 MMQ is consistently faster than dequantization + hipBLAS:
+ // https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301
+ return true;
+ }
+
+    return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/mmq.cuh b/llama.cpp/ggml/src/ggml-cuda/mmq.cuh
new file mode 100644
index 0000000..f80f98c
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/mmq.cuh
@@ -0,0 +1,4092 @@
+#pragma once
+
+#include "common.cuh"
+#include "vecdotq.cuh"
+#include "mma.cuh"
+
+#include <climits>
+#include <cstdint>
+
+using namespace ggml_cuda_mma;
+
+#define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
+#define MMQ_ITER_K 256
+#define MMQ_ITER_K_MXFP4_FP4 512
+#define MMQ_NWARPS 8
+
+typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
+typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
+typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted,
+ float * __restrict__ dst, const int stride, const int i_max, const int j_max);
+
+enum mmq_q8_1_ds_layout {
+ MMQ_Q8_1_DS_LAYOUT_D4,
+ MMQ_Q8_1_DS_LAYOUT_DS4,
+ MMQ_Q8_1_DS_LAYOUT_D2S6,
+};
+
+struct block_q8_1_mmq {
+ // The y float data is converted to a data layout that can simply be copied to shared memory as a contiguous block.
+ // The y float data is first grouped as blocks of 128 values.
+ // These blocks are then treated as individual data values and transposed.
+ //
+ // To avoid shared memory bank conflicts each block is padded with 16 bytes.
+ // This padding is also used to store block scales/partial sums.
+ // The scales multiplied with the quantized data are equal to the unquantized values.
+ // The partial sums are obtained by summing up a subgroup of the contained values (prior to quantization)
+ // and are only needed for performance reasons.
+ //
+ // The exact data stored depends on the x data type.
+ union {
+ float d4[4]; // 1 32 bit scale per 32 values, stored as d0,d1,d2,d3
+ half2 ds4[4]; // 1 16 bit scale + 1 16 bit partial sum per 32 values, stored as d0,s0,d1,s1,d2,s2,d3,s3
+ half d2s6[8]; // 1 16 bit scale per 64 values + 1 16 bit partial sum per 16 values for the first 96 values,
+                       // stored as d0,d1,s1,s2,s3,s4,s5,s6
+ };
+ int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
+};
+
+struct block_fp4_mmq {
+ uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
+ int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
+};
+
+static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
+static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
+static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
+
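The size relations asserted above follow directly from the layout comments: with QK8_1 == 32 the qs array holds 4*32 = 128 bytes of quantized data and the scale/partial-sum union adds 16 bytes, for 144 bytes total, which is exactly four 36-byte block_q8_1; block_fp4_mmq is defined to match. A standalone sketch of the same arithmetic (illustrative only, with ggml's QK8_1 value hard-coded):

    #include <cstdio>

    int main() {
        const int QK8_1  = 32;       // values per block_q8_1 in ggml
        const int qs     = 4*QK8_1;  // 128 quantized int8 values
        const int scales = 4*4;      // 16 bytes of scales/partial sums (the union)
        printf("block_q8_1_mmq: %d bytes == 4 * %d-byte block_q8_1\n", qs + scales, QK8_1 + 4);
        return 0;
    }
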
+static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
+ switch (type_x) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ return MMQ_Q8_1_DS_LAYOUT_DS4;
+ case GGML_TYPE_Q5_0:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
+ case GGML_TYPE_Q5_1:
+ return MMQ_Q8_1_DS_LAYOUT_DS4;
+ case GGML_TYPE_Q8_0:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
+ case GGML_TYPE_MXFP4:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
+ case GGML_TYPE_Q2_K:
+ return MMQ_Q8_1_DS_LAYOUT_D2S6;
+ case GGML_TYPE_Q3_K:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ return MMQ_Q8_1_DS_LAYOUT_DS4;
+ case GGML_TYPE_Q6_K:
+ case GGML_TYPE_IQ2_XXS:
+ case GGML_TYPE_IQ2_XS:
+ case GGML_TYPE_IQ2_S:
+ case GGML_TYPE_IQ3_XXS:
+ case GGML_TYPE_IQ3_S:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
+ case GGML_TYPE_IQ1_S:
+ return MMQ_Q8_1_DS_LAYOUT_DS4;
+ case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_IQ4_NL:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+}
+
+struct tile_x_sizes {
+ int qs;
+ int dm;
+ int sc;
+};
+
+static int get_mmq_x_max_host(const int cc) {
+ return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 :
+ GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
+#ifdef GGML_CUDA_FORCE_MMQ
+ 128 : 64;
+#else
+ MMQ_DP4A_MAX_BATCH_SIZE : 64;
+#endif // GGML_CUDA_FORCE_MMQ
+}
+
+static constexpr __device__ int get_mmq_x_max_device() {
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ return 128;
+#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+#if defined(GGML_USE_HIP)
+ return 64;
+#else // defined(GGML_USE_HIP)
+
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#ifdef GGML_CUDA_FORCE_MMQ
+ return 128;
+#else // GGML_CUDA_FORCE_MMQ
+ return MMQ_DP4A_MAX_BATCH_SIZE;
+#endif // GGML_CUDA_FORCE_MMQ
+#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+ return 64;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+
+#endif // defined(GGML_USE_HIP)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+}
+
+static int get_mmq_y_host(const int cc) {
+ return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
+ ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
+}
+
+static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
+#if defined(BLACKWELL_MMA_AVAILABLE)
+ return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
+#else
+ return MMQ_ITER_K;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+}
+
+static constexpr __device__ int get_mmq_y_device() {
+#if defined(GGML_USE_HIP)
+#if defined(RDNA1)
+ return 64;
+#else
+ return 128;
+#endif // defined RDNA1
+#else
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+ return 128;
+#else
+ return 64;
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#endif // defined(GGML_USE_HIP)
+}
+
+// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes.
+// The K dimension of the tiles has either 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K)
+// or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K) 32 bit elements of quantized data (scales not included).
+// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K.
+// The final tile size in the K direction is padded to avoid shared memory bank conflicts;
+// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma.
+#define MMQ_TILE_NE_K 32
+
+#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0}
+#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0}
+#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0}
+#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0}
+#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0}
+#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0}
+#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8}
+
+static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
+ switch (type) {
+ case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
+ case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
+ case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
+ case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
+ case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
+ case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
+ case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
+ case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K;
+ case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K;
+ case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16;
+ case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16;
+ case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0;
+ default: return tile_x_sizes{0, 0, 0};
+ }
+}
+
+#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
+#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
+#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7)
+
+static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
+
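As a worked example of that padding rule (assuming QI8_0 == 8, ggml's value): MMQ_MMA_TILE_X_K_Q8_0 = 2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4 = 64 + 8 + 4 = 76, and 76 % 8 == 4, which is exactly what the static_asserts above require for the mma tiles.
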
+static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
+ case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
+ case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
+ // tile sizes are the same for Q8_1 and FP4 for blackwell
+ case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
+ case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
+ case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
+ case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
+ case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1;
+ case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K;
+ case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K;
+ case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K;
+ case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0;
+ default: return 0;
+ }
+}
+
+// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
+#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1)
+#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
+
+static int mmq_get_granularity_host(const int mmq_x, const int cc) {
+ if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
+ return mmq_x >= 128 ? 32 : 16;
+ } else if (turing_mma_available(cc) && mmq_x >= 48) {
+ return 16;
+ } else {
+ return 8;
+ }
+}
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
+ return mmq_x >= 128 ? 32 : 16;
+}
+#elif defined(TURING_MMA_AVAILABLE)
+static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) {
+ return mmq_x >= 48 ? 16 : 8;
+}
+#else
+static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) {
+ return 8;
+}
+#endif // AMD_MFMA_AVAILABLE
+
+#if defined(GGML_USE_HIP)
+static int mmq_get_nwarps_host(const int cc, const int warp_size) {
+ return amd_mfma_available(cc) ? 8 : 256/warp_size;
+}
+#else
+static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
+ return 256/warp_size;
+}
+#endif // (GGML_USE_HIP)
+
+static constexpr __device__ int mmq_get_nwarps_device() {
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ return 8;
+#else
+ return 256/ggml_cuda_get_physical_warp_size();
+#endif // AMD_MFMA_AVAILABLE
+}
+
+// ------------------------------------------------------------
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+ const int kbx = txi / QI4_0;
+ const int kqsx = txi % QI4_0;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
+ const int qs0 = get_int_b2(bxi->qs, kqsx);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808);
+#else
+ x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0;
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#else
+ x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+}
+
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + txs.qs;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+
+ const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
+
+ int u[2*VDR_Q4_0_Q8_1_MMQ];
+
+#pragma unroll
+ for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
+ u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
+ u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)];
+ }
+
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
+ (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u,
+ x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+ const int kbx = txi / QI4_1;
+ const int kqsx = txi % QI4_1;
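+    // txi is the lane's index within its row: on devices whose physical warp is wider than
+    // threads_per_row (e.g. 64-wide AMD waves) one warp loads several rows at a time.
+    // kbx selects the q4_1 block within the row, kqsx the int within that block.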
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
+ const int qs0 = get_int_b4(bxi->qs, kqsx);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F;
+#else
+ x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1;
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
+#else
+ x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+}
+
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+
+ const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2);
+
+ int u[2*VDR_Q4_1_Q8_1_MMQ];
+
+#pragma unroll
+ for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
+ u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l];
+ u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)];
+ }
+
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
+ (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u,
+ x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+ const int kbx = txi / QI5_0;
+ const int kqsx = txi % QI5_0;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
+
+ const int ql = get_int_b2(bxi->qs, kqsx);
+ const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx);
+
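+        // Reconstruct the 5-bit values: ql supplies the low 4 bits and qh supplies bit 4 of each
+        // value; the shifts below move the matching qh bits into bit 4 of byte lanes 0..3
+        // (bit positions 4, 12, 20 and 28 of the 32-bit word). Subtracting 16 recenters to [-16, 15].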
+ int qs0 = (ql >> 0) & 0x0F0F0F0F;
+ qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
+ qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
+ qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
+ qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
+ qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
+
+ int qs1 = (ql >> 4) & 0x0F0F0F0F;
+ qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
+ qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
+ qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
+ qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
+ qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0;
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#else
+ x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+ const int kbx = txi / QI5_1;
+ const int kqsx = txi % QI5_1;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
+
+ const int ql = get_int_b4(bxi->qs, kqsx);
+ const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx);
+
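+        // Same 5-bit reconstruction as for q5_0, but without subtracting 16: q5_1 stores an
+        // explicit minimum in dm instead of using a fixed offset.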
+ int qs0 = (ql >> 0) & 0x0F0F0F0F;
+ qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
+ qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
+ qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
+ qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
+
+ int qs1 = (ql >> 4) & 0x0F0F0F0F;
+ qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
+ qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
+ qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
+ qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1;
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm;
+#else
+ x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+    // MMQ_ITER_K / (4 * QR8_0) == 64 would be required, but NVIDIA GPUs have only 32 threads per warp.
+ constexpr int threads_per_row = 32;
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+ const int kbx = txi / QI8_0;
+ const int kqsx = txi % QI8_0;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
+
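+        // With 32 threads per row each thread loads one int from two consecutive q8_0 blocks,
+        // covering the full 2*MMQ_TILE_NE_K ints of the tile row.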
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx);
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+ constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0;
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d;
+#else
+ x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+ const int kbx = txi / QI_MXFP4;
+ const int kqsx = txi % QI_MXFP4;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;
+
+ const int aux_q4 = get_int_b1(bxi->qs, kqsx);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+ const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
+
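+        // v.x/v.y hold the int8 values expanded from the low/high nibbles of aux_q4 via the
+        // kvalues_mxfp4 lookup table.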
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
+
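+        // The extra 0.5f compensates for the kvalues_mxfp4 table storing doubled integer values
+        // (assumption based on the shared lookup-table convention).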
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
+#else
+ x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+}
+
+template <int mmq_y, bool need_check>
+static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
+ int * __restrict__ x_tile,
+ const int kbx0,
+ const int i_max,
+ const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ int * x_qs = (int *) x_tile;
+ uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
+
+ const int txi = threadIdx.x;
+
+ constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);
+
+ constexpr int threads_per_row = iter_k / QK_MXFP4; // each thread processes 1 block
+ constexpr int rows_per_warp = warp_size / threads_per_row;
+ const int kbx = txi % threads_per_row;
+ const int row_in_warp = txi / threads_per_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
+ int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
+
+ if constexpr (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
+
+ // quantize_mxfp4_mmq permutes nibbles to match the quantized format
+ const int k0 = kbx * 4;
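+        // 16 bytes == QK_MXFP4/2 packed nibbles == one full block of quantized values, copied as 4 ints.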
+ memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
+
+ // Load E8M0 scales: pack 2 consecutive scales into one uint32
+ if (kbx % 2 == 0) {
+ uint32_t e = bxi->e;
+ e |= ((bxi + 1)->e << 8);
+ x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
+ }
+ }
+}
+
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + txs.qs;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
+ (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K],
+ x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]);
+ }
+ }
+ }
+}
+
+template <int mmq_x, int mmq_y, mmq_q8_1_ds_layout ds_layout>
+static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ constexpr data_layout input_layout = get_input_data_layout();
+ typedef tile<16, 8, int, input_layout> tile_A;
+ typedef tile<16, 8, int, input_layout> tile_B;
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+ const half2 * y_ds = (const half2 *) y;
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B;
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ float dB;
+ const int j = j0 + tile_C::get_j(0);
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+ dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+ } else {
+ dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C;
+ mma(C, A[n], B);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_A::I + tile_C::get_i(l);
+ const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB;
+ }
+ }
+ }
+ }
+#else
+ typedef tile<16, 8, int> tile_A;
+ typedef tile< 8, 8, int> tile_B;
+ typedef tile<16, 8, int> tile_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+ const half2 * y_ds = (const half2 *) y;
+
+ tile_A A[ntx][MMQ_TILE_NE_K/QI8_0];
+ float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0];
+
+ const int i0 = (threadIdx.y/ntx)*rows_per_warp;
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
+ const int k0 = k00 + k01;
+
+ load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+ }
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
+
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
+ const int k0 = k00 + k01;
+
+ dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+ }
+ }
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
+ tile_B B;
+ float dB[tile_C::ne/2];
+
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int j = j0 + tile_C::get_j(l);
+
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+ } else {
+ dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C;
+ mma(C, A[n][k01/QI8_0], B);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+ }
+ }
+ }
+ }
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+}
+
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
+ const int * __restrict__ y,
+ float * __restrict__ sum,
+ const int k00) {
+ typedef tile<16, 8, int> tile_A;
+ typedef tile<8, 8, int> tile_B;
+ typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
+
+ // Match layout from load_tiles_mxfp4_fp4
+ const int * x_qs = (const int *) x;
+ const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
+ const int * y_qs = (const int *) y + 4;
+ const uint32_t * y_sc = (const uint32_t *) y;
+
+ // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
+ tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
+ uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
+
+    // Block scales:
+    // each thread has to point to a 4-byte scale value, see
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
+ const int k0 = k00 + k01;
+
+ load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
+ MMQ_MMA_TILE_X_K_FP4);
+
+            // per the block-scaling documentation, 2 threads in each quad need to supply the scale value
+ const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
+ scaleA[n][k01 / (2 * QI_MXFP4)] =
+ *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
+ }
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
+ tile_B B;
+ uint32_t scaleB; // 2xN scales
+
+ load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
+
+ scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C;
+
+ mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
+ }
+ }
+ }
+ }
+}
+
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
+ (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+ x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ constexpr data_layout input_layout = get_input_data_layout();
+ typedef tile<16, 8, int, input_layout> tile_A;
+ typedef tile<16, 8, int, input_layout> tile_B;
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_dm = (const half2 *) y;
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B;
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ const int j = j0 + tile_C::get_j(0);
+ const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C;
+ mma(C, A[n], B);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_A::I + tile_C::get_i(l);
+ float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l];
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y;
+ }
+ }
+ }
+ }
+#else
+ typedef tile<16, 8, int> tile_A;
+ typedef tile< 8, 8, int> tile_B;
+ typedef tile<16, 8, int> tile_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_dm = (const half2 *) y;
+
+ tile_A A[ntx][MMQ_TILE_NE_K/QI8_1];
+ float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1];
+
+ const int i0 = (threadIdx.y/ntx)*rows_per_warp;
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
+ const int k0 = k00 + k01;
+
+ load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1);
+ }
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
+
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
+ const int k0 = k00 + k01;
+
+ dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]);
+ }
+ }
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
+ tile_B B;
+ float2 dsB[tile_C::ne/2];
+
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int j = j0 + tile_C::get_j(l);
+
+ dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C;
+ mma(C, A[n][k01/QI8_1], B);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
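+                    // dmA.x*dsB.x scales the integer dot product; dmA.y*dsB.y adds the
+                    // m * sum(q8) correction carried in the q8_1 per-block sums.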
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l];
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y;
+ }
+ }
+ }
+ }
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+}
+
+// Used for Q3_K, IQ2_S, and IQ2_XS
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + txs.qs;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl<QI8_0>(
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0],
+ &y_qs[j*MMQ_TILE_Y_K + k01],
+ &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)],
+ y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
+// Used for Q3_K, IQ2_S, and IQ2_XS:
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+#if defined(AMD_MFMA_AVAILABLE)
+ constexpr data_layout input_layout = get_input_data_layout();
+ typedef tile<16, 8, int, input_layout> tile_A;
+ typedef tile<16, 8, int, input_layout> tile_B;
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+ typedef tile<64, 2, int, input_layout> tile_load;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B[1];
+ load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ const int j = j0 + tile_C::get_j(0);
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C;
+ mma(C, A[n], B[0]);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
+ }
+ }
+ }
+ }
+#elif defined(AMD_WMMA_AVAILABLE) // WMMA instructions can handle 16x4 tiles and do not require loading 64x2 tiles
+ constexpr data_layout input_layout = get_input_data_layout();
+ typedef tile<16, 4, int, input_layout> tile_A;
+ typedef tile<16, 4, int, input_layout> tile_B;
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B;
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ const int j = j0 + tile_C::get_j(0);
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C;
+ mma(C, A[n], B);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB;
+ }
+ }
+ }
+ }
+#elif defined(TURING_MMA_AVAILABLE)
+
+ typedef tile<16, 4, int> tile_A;
+ typedef tile<16, 8, int> tile_A_8;
+ typedef tile< 8, 4, int> tile_B;
+ typedef tile<16, 8, int> tile_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+ const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
+
+ tile_A A[ntx][8];
+ float dA[ntx][tile_C::ne/2][8];
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
+ const int k0 = k00 + k01;
+
+ load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K);
+ }
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
+
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+ const int k0 = k00 + k01;
+
+ dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4];
+ }
+ }
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+ tile_B B[2];
+ float dB[tile_C::ne/2];
+
+ // Here load_generic is faster than load_ldmatrix.
+ load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
+ load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int j = j0 + tile_C::get_j(l);
+
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C[2];
+ mma(C[0], A[n][k01/4 + 0], B[0]);
+ mma(C[1], A[n][k01/4 + 1], B[1]);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]);
+ }
+ }
+ }
+ }
+#else
+ GGML_UNUSED_VARS(x, y, sum, k00);
+ NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE || TURING_MMA_AVAILABLE
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K);
+ constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row;
+ const int kqsx = threadIdx.x % threads_per_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride;
+
+ const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
+
+#pragma unroll
+ for (int l = 0; l < QR2_K; ++l) {
+ const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
+
+ const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k;
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
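+        // Each q2_K scales byte packs a 4-bit scale in the low nibble and a 4-bit min in the high nibble.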
+ const int sc_m = bxi->scales[kqsx];
+#ifdef FAST_FP16_AVAILABLE
+ const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));
+#else
+ const float2 bxi_dmf = __half22float2(bxi->dm);
+ const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
+#endif // FAST_FP16_AVAILABLE
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik;
+#else
+ x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+}
+
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+ float2 y_df[mmq_x/nwarps];
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ y_df[j0/nwarps] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
+ }
+
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+
+ constexpr int ns = 2;
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+ &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+ &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+ }
+ }
+ }
+
+ // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
+ // As a workaround 2 separate loops are used instead.
+#pragma unroll
+ for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+
+ constexpr int ns = 1;
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+ &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+ &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+ }
+ }
+ }
+}
+
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+#if defined(AMD_MFMA_AVAILABLE)
+ constexpr data_layout input_layout = get_input_data_layout();
+ typedef tile<16, 8, int, input_layout> tile_A;
+ typedef tile<16, 8, int, input_layout> tile_B;
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+ typedef tile<64, 2, int, input_layout> tile_load;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B[1];
+ load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ const int j = j0 + tile_C::get_j(0);
+ const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2;
+ const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
+ : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
+ : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
+
+ tile_C Cm;
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
+ tile_A A1;
+ A1.x[0] = 0x01010101;
+ A1.x[1] = 0x01010101;
+ mma(Cm, A1, B[0]);
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C Cd;
+ mma(Cd, A[n], B[0]);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+ const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
+ float tmp = Cd.x[l]*dm.x;
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
+ tmp -= Cm.x[l]*dm.y;
+ }
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
+ }
+ }
+ }
+ }
+#elif defined(AMD_WMMA_AVAILABLE) // WMMA instructions can handle 16x4 tiles and do not require loading 64x2 tiles
+ constexpr data_layout input_layout = get_input_data_layout();
+ typedef tile<16, 4, int, input_layout> tile_A;
+ typedef tile<16, 4, int, input_layout> tile_B;
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B;
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ const int j = j0 + tile_C::get_j(0);
+ const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y;
+ const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0
+ : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y
+ : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x);
+
+ tile_C Cm;
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
+ tile_A A1;
+#pragma unroll
+ for (int l = 0; l < tile_A::ne; ++l) {
+ A1.x[l] = 0x01010101;
+ }
+ mma(Cm, A1, B);
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C Cd;
+ mma(Cd, A[n], B);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+ const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]);
+ float tmp = Cd.x[l]*dm.x;
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
+ tmp -= Cm.x[l]*dm.y;
+ }
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB;
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB;
+ }
+ }
+ }
+ }
+#elif defined(TURING_MMA_AVAILABLE)
+
+ typedef tile<16, 4, int> tile_A;
+ typedef tile<16, 8, int> tile_A_8;
+ typedef tile< 8, 4, int> tile_B;
+ typedef tile<16, 8, int> tile_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+ const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
+
+ tile_A A[ntx][8];
+ float dA[ntx][tile_C::ne/2][8];
+ float mA[ntx][tile_C::ne/2][8];
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
+ const int k0 = k00 + k01;
+
+ load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K);
+ }
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
+
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) {
+ const int k0 = k00 + k01;
+
+ const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]);
+
+ dA[n][l][k01/(QI8_1/2)] = dm.x;
+ mA[n][l][k01/(QI8_1/2)] = dm.y;
+ }
+ }
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ float2 dB[tile_C::ne/2];
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int j = j0 + tile_C::get_j(l);
+
+ dB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K]);
+ }
+
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) {
+ tile_B B[2];
+
+ // Here load_generic is faster than load_ldmatrix.
+ load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K);
+ load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K);
+
+ tile_C Cm[2];
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
+ tile_A A1;
+ A1.x[0] = 0x01010101;
+ A1.x[1] = 0x01010101;
+ mma(Cm[0], A1, B[0]);
+ mma(Cm[1], A1, B[1]);
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C Cd[2];
+
+ mma(Cd[0], A[n][k01/4 + 0], B[0]);
+ mma(Cd[1], A[n][k01/4 + 1], B[1]);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1];
+ if (k01 >= MMQ_TILE_NE_K * 3/4) {
+ tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1];
+ }
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? dB[l%2].x : dB[l%2].y);
+ }
+ }
+ }
+
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) {
+ float2 sB[tile_C::ne/2];
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int j = j0 + tile_C::get_j(l);
+
+ sB[l] = __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 0]*sB[l%2].x;
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] -= mA[n][l/2][k01/4 + 1]*sB[l%2].y;
+ }
+ }
+ }
+ }
+#else
+ GGML_UNUSED_VARS(x, y, sum, k00);
+ NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE || TURING_MMA_AVAILABLE
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+ int * x_sc = (int *) (x_df + txs.dm);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int kqsx = threadIdx.x % threads_per_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+
+ const int x_ql_0 = get_int_b2(bxi->qs, kqsx);
+ const int x_qh_0 = get_int_b2(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
+
+#pragma unroll
+ for (int l = 0; l < QR3_K; ++l) {
+ const int k = (kqsx/8)*32 + l*8 + kqsx % 8;
+
+ const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303;
+ const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404;
+
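+            // Combine the 2 low bits from qs with the high bit from hmask, then subtract 4
+            // to recenter the q3_K values to [-4, 3].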
+ const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k;
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+ }
+
+ constexpr int rows_per_warp = warp_size / 4;
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+ int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+
+ const int ksc = threadIdx.x % 4;
+
+ const int ksc_low = ksc % (QI3_K/8);
+ const int shift_low = 4 * (ksc / (QI3_K/8));
+ const int sc_low = (get_int_b2(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+
+ const int ksc_high = QI3_K/8;
+ const int shift_high = 2 * ksc;
+ const int sc_high = ((get_int_b2(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+
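+        // q3_K scales are 6 bits each (4 low bits plus 2 high bits); subtracting 32 recenters them to [-32, 31].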
+ const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ const int8_t * sc8 = (const int8_t *) &sc;
+ const float d = bxi->d;
+
+#pragma unroll
+ for (int l = 0; l < int(sizeof(int)); ++l) {
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l];
+ }
+#else
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+#if !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride;
+
+ x_df[i] = bxi->d;
+ }
+#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE))
+}
+
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + txs.qs;
+ const int * x_sc = (const int *) x_df + txs.dm;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+
+ const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4;
+
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq(
+ &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales,
+ x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
+static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) {
+ // scale arrangement after the following two lines:
+ // - ksc == 0: sc0, sc1, sc2, sc3
+ // - ksc == 1: sc4, sc5, sc6, sc7
+ // - ksc == 2: m0, m1, m2, m3
+ // - ksc == 3: m4, m5, m6, m7
+ return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits
+ ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + txs.qs);
+ int * x_sc = (int *) (x_dm + txs.dm);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+ const int qs0 = get_int_b4(bxi->qs, txi);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F;
+#else
+ x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ constexpr int rows_per_warp = warp_size / 2;
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        // On AMD an 'if' is needed here instead of '%' because warp_size == 64;
+        // this causes double work and a throughput loss (MI300X).
+        // On H100 the 'if' condition costs about 100 t/s compared to '%'.
+ int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
+ if (i < mmq_y) {
+#else
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
+ {
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+
+ const int * scales = (const int *) bxi->scales;
+ const int ksc = threadIdx.x % 2;
+
+ const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+ const int m32 = unpack_scales_q45_K(scales, ksc + 2);
+
+ const uint8_t * sc8 = (const uint8_t *) &sc32;
+ const uint8_t * m8 = (const uint8_t *) &m32;
+
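+            // bxi->dm packs (d, dmin); multiplying by (1, -1) gives (d, -dmin) so x_dm below stores the fused per-subblock pair (d*sc, -dmin*m).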
+ const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
+
+#pragma unroll
+            for (int l = 0; l < int(sizeof(int)); ++l) {
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+ }
+ }
+ }
+#else
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride;
+
+ x_dm[i] = bxi->dm;
+ }
+ constexpr int rows_per_warp = warp_size / 4;
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8);
+
+ const int * scales = (const int *) bxi->scales;
+
+ const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
+ const int scales8 = unpack_scales_q45_K(scales, ksc);
+
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
+ }
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+}
+
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+ const int * x_sc = (const int *) x_dm + txs.dm;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+
+ const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16);
+
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq(
+ &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+ x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_dm = (half2 *) (x_qs + txs.qs);
+ int * x_sc = (int *) (x_dm + txs.dm);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+ const int ky = QR5_K*txi;
+
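+        // The low 4 bits come from qs and the 5th bit from qh, giving unsigned 5-bit values; the per-block min is applied later through x_dm.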
+ const int ql = get_int_b4(bxi->qs, txi);
+ const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+ const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+ const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4));
+ const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+ const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+
+ const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0;
+ const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1;
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ constexpr int rows_per_warp = warp_size / 2;
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+#if defined(AMD_MFMA_AVAILABLE)
+        // On AMD an 'if' is needed here instead of '%' because warp_size == 64;
+        // with '%' the wider warp would redo rows and lose throughput (MI300X).
+        // On H100 the 'if' variant is about 100 t/s slower than '%'.
+ int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2;
+ if (i < mmq_y) {
+#else
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
+ {
+#endif // defined(AMD_MFMA_AVAILABLE)
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+
+ const int * scales = (const int *) bxi->scales;
+ const int ksc = threadIdx.x % 2;
+
+ const int sc32 = unpack_scales_q45_K(scales, ksc + 0);
+ const int m32 = unpack_scales_q45_K(scales, ksc + 2);
+
+ const uint8_t * sc8 = (const uint8_t *) &sc32;
+ const uint8_t * m8 = (const uint8_t *) &m32;
+
+ const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
+
+#pragma unroll
+ for (int l = 0; l < int(sizeof(int)); ++l) {
+ x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
+ }
+ }
+ }
+#else
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+
+ x_dm[i] = bxi->dm;
+ }
+
+ constexpr int rows_per_warp = warp_size / 4;
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride;
+
+ const int * scales = (const int *) bxi->scales;
+
+ const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8);
+ const int scales8 = unpack_scales_q45_K(scales, ksc);
+
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8;
+ }
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+}
+
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
+ const int * x_qs = (const int *) x;
+ const half2 * x_dm = (const half2 *) x_qs + txs.qs;
+ const int * x_sc = (const int *) x_dm + txs.dm;
+ const int * y_qs = (const int *) y + 4;
+ const half2 * y_ds = (const half2 *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+
+ const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16);
+
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq(
+ &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8,
+ x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+ int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+ int * x_sc = (int *) (x_df + txs.dm);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
+
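+        // ql holds the low 4 bits and qh the upper 2 bits of each 6-bit value; __vsubss4 below recenters [0, 63] to [-32, 31].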
+ const int ql = get_int_b2(bxi->ql, txi);
+ const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+ const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+ const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4));
+ const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030;
+ const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030;
+
+ const int kq0 = 2*txi - txi % (QI6_K/2) + 0;
+ const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+ x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) {
+ int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d;
+#else
+ x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+ constexpr int rows_per_warp = warp_size / 4;
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) {
+ int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8));
+#else
+ x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8));
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+}
+
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + txs.qs;
+ const int * x_sc = (const int *) x_df + txs.dm;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+// #pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+
+ const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]);
+
+ sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq(
+ &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc,
+ x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]);
+ }
+ }
+ }
+}
+
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
+ const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
+#if defined(AMD_MFMA_AVAILABLE)
+ constexpr data_layout input_layout = get_input_data_layout();
+ typedef tile<16, 8, int, input_layout> tile_A;
+ typedef tile<16, 8, int, input_layout> tile_B;
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+ typedef tile<64, 2, int, input_layout> tile_load;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+ const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B[1];
+ load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ const int j = j0 + tile_C::get_j(0);
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2;
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C;
+ mma(C, A[n], B[0]);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+ const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
+ }
+ }
+ }
+ }
+#elif defined(AMD_WMMA_AVAILABLE) // WMMA instructions can handle 16x4 tiles directly, so loading 64x2 tiles is not required
+ constexpr data_layout input_layout = get_input_data_layout();
+ typedef tile<16, 4, int, input_layout> tile_A;
+ typedef tile<16, 4, int, input_layout> tile_B;
+ typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+ const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+ const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) {
+ const int k0 = k00 + k01;
+
+ tile_A A[ntx];
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K);
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ tile_B B;
+ load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K);
+
+ const int j = j0 + tile_C::get_j(0);
+ const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C;
+ mma(C, A[n], B);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+ const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16);
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB;
+ }
+ }
+ }
+ }
+#elif defined(TURING_MMA_AVAILABLE)
+
+ typedef tile<16, 4, int> tile_A;
+ typedef tile< 8, 4, int> tile_B;
+ typedef tile<16, 8, int> tile_C;
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int rows_per_warp = 2 * granularity;
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K);
+
+ const int * x_qs = (const int *) x;
+ const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2;
+ const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K;
+ const int * y_qs = (const int *) y + 4;
+ const float * y_df = (const float *) y;
+
+ const int i0 = (threadIdx.y / ntx) * (ntx*tile_A::I);
+
+ tile_A A[ntx][8];
+ int scA[ntx][tile_C::ne/2][8];
+ float dA[ntx][tile_C::ne/2];
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
+ const int k0 = k00 + k01;
+
+ load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K);
+ load_ldmatrix(A[n][k01/4 + 1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + tile_A::J), MMQ_MMA_TILE_X_K_Q6_K);
+ }
+
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) {
+ const int k0 = k00 + k01;
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
+
+ const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + k0/16];
+ const int8_t * sc = (const int8_t *) &sc_packed;
+
+#pragma unroll
+                for (int ksc = 0; ksc < int(sizeof(int)); ++ksc) {
+ scA[n][l][k01/4 + ksc] = sc[ksc];
+ }
+ }
+ }
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int i = i0 + n*tile_C::I + tile_C::get_i(2*l);
+
+ dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q6_K];
+ }
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+ float tmp[ntx][tile_C::ne] = {{0.0f}};
+
+#pragma unroll
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) {
+ tile_B B[2];
+ float dB[tile_C::ne/2];
+
+ // Here load_generic is faster than load_ldmatrix.
+ load_generic(B[0], y_qs + j0*MMQ_TILE_Y_K + 0 + k01, MMQ_TILE_Y_K);
+ load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + tile_B::J + k01, MMQ_TILE_Y_K);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne/2; ++l) {
+ const int j = j0 + tile_C::get_j(l);
+
+ dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+ tile_C C[2];
+ mma(C[0], A[n][k01/4 + 0], B[0]);
+ mma(C[1], A[n][k01/4 + 1], B[1]);
+
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ tmp[n][l] += (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1])*dB[l%2];
+ }
+ }
+ }
+
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp[n][l]*dA[n][l/2];
+ }
+ }
+ }
+#else
+ GGML_UNUSED_VARS(x, y, sum, k00);
+ NO_DEVICE_CODE;
+#endif // AMD_MFMA_AVAILABLE || AMD_WMMA_AVAILABLE || TURING_MMA_AVAILABLE
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+ const int kbx = txi / QI4_NL;
+ const int kqsx = txi % QI4_NL;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
+
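+        // Each 4-bit index is mapped through the nonlinear kvalues_iq4nl table; v.x holds the values decoded from the low nibbles, v.y those from the high nibbles.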
+ const int aux_q4 = get_int_b2(bxi->qs, kqsx);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
+ const int k0 = kbx * (2 * QI4_NL) + kqsx;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y;
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL;
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d);
+#else
+ x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2;
+ constexpr int nrows = warp_size / threads_per_row;
+ const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride;
+
+ const int q2 = get_int_b2(bxi->qs, 2*kqsx+0);
+ const uint8_t * aux8 = (const uint8_t *) &q2;
+ const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1);
+
+#pragma unroll
+ for (int l = 0; l < QR2_XXS; ++l) {
+ const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
+ const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
+
+ const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
+ const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
+
+ const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
+ const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1;
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
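+        // ls is the 4-bit block scale; (ls*d + d/2)/4 == d*(ls + 0.5f)/4, matching the reference iq2_xxs dequantization factor.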
+ const int ls = aux32 >> 28;
+ const float d = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4;
+#else
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+ constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2;
+ constexpr int nrows = warp_size / threads_per_row;
+ const int kqsx = threadIdx.x % threads_per_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride;
+
+ const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+ const uint16_t * q2 = (const uint16_t *) &q2_packed;
+
+ #pragma unroll
+ for (int l = 0; l < QR2_XS; ++l) {
+ const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
+ const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9));
+
+ const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
+ const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+ const int ls = bxi->scales[kqsx];
+ const float d = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#else
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2;
+ constexpr int nrows = warp_size / threads_per_row;
+ const int kqsx = threadIdx.x % threads_per_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride;
+
+ const int qs_packed = get_int_b2(bxi->qs, kqsx);
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ const int qh = bxi->qh[kqsx];
+
+ const int signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx);
+ const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
+
+#pragma unroll
+ for (int l = 0; l < QR2_S; ++l) {
+ const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300)));
+
+ const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
+ const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
+
+ const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
+ const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+ const int ls = bxi->scales[kqsx];
+ const float d = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+ x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#else
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4;
+ x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2;
+ constexpr int nrows = warp_size / threads_per_row;
+ const int kqsx = threadIdx.x % threads_per_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride;
+
+ const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+ const uint8_t * q3 = (const uint8_t *) &q3_packed;
+ const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx);
+
+#pragma unroll
+ for (int l = 0; l < QR3_XXS; ++l) {
+ const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
+
+ const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
+
+ const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
+ const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h;
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+ const int ls = aux32 >> 28;
+ const float d = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2;
+#else
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2;
+ constexpr int nrows = warp_size / threads_per_row;
+ const int kqsx = threadIdx.x % threads_per_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride;
+
+ const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1));
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ const int qh = bxi->qh[kqsx];
+
+ const int signs_packed_32 = get_int_b2(bxi->signs, kqsx);
+ const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
+
+#pragma unroll
+ for (int l = 0; l < QR3_S; ++l) {
+ const int2 grid_pos = make_int2(
+ iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)],
+ iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]);
+
+ const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000);
+ const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000);
+
+ const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
+ const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h;
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+ const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F);
+ const float d = bxi->d;
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d;
+#else
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y);
+ int * x_qs = (int *) x_tile;
+ half2 * x_ds = (half2 *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int kqsx = threadIdx.x % threads_per_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) {
+ int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride;
+
+ const int qs_packed = get_int_b2(bxi->qs, kqsx);
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ const int qh = bxi->qh[kqsx];
+
+ #pragma unroll
+ for (int l = 0; l < QR1_S/2; ++l) {
+ const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)];
+
+ const int grid0 = (grid >> 0) & 0x0F0F0F0F;
+ const int grid1 = (grid >> 4) & 0x0F0F0F0F;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1;
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
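+        // The 3-bit block scale sits in qh bits 12..14; bit 15 selects the sign of the IQ1S_DELTA shift, which is applied per block through the second component of x_ds.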
+ const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1);
+ const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta);
+#else
+ x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+}
+
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int kqsx = threadIdx.x % threads_per_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
+
+ const int aux_q4 = get_int_b4(bxi->qs, kqsx);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
+ const int k0 = 8 * (kqsx / 4) + kqsx % 4;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+
+ constexpr int rows_per_warp = warp_size / 8;
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
+
+ const float d = __half2float(bxi->d);
+
+ const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F)
+ | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32);
+#else
+ x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ }
+}
+
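+// Scatter the per-thread partial sums into dst; ids_dst maps a tile column j to its destination row (identity for plain matmul, expert-permuted for MoE).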
+template<int mmq_x, int mmq_y, bool need_check>
+static __device__ __forceinline__ void mmq_write_back_dp4a(
+ const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
+ const int stride, const int i_max, const int j_max) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (j > j_max) {
+ return;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+
+ if (need_check && i > i_max) {
+ continue;
+ }
+
+ dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
+ }
+ }
+}
+
+template<ggml_type type, int mmq_x, int mmq_y, bool need_check>
+static __device__ __forceinline__ void mmq_write_back_mma(
+ const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
+ const int stride, const int i_max, const int j_max) {
+
+ constexpr int granularity = mmq_get_granularity_device(mmq_x);
+ constexpr int nwarps = mmq_get_nwarps_device();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ constexpr int tileC_IJ = mmq_get_granularity_device(0);
+ typedef tile<tileC_IJ, tileC_IJ, int, DATA_LAYOUT_J_MAJOR> tile_C;
+ constexpr int rows_per_warp = granularity;
+#else
+ typedef tile<16, 8, int> tile_C;
+ constexpr int rows_per_warp = 2 * granularity;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
+
+ const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
+#if defined(TURING_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y");
+#else
+ GGML_UNUSED(nwarps);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
+#pragma unroll
+ for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int j = j0 + (threadIdx.y % ntx) * tile_C::J + tile_C::get_j(l);
+
+ if (j > j_max) {
+ continue;
+ }
+
+ const int i = i0 + n*tile_C::I + tile_C::get_i(l);
+
+ if (need_check && i > i_max) {
+ continue;
+ }
+
+ dst[ids_dst[j]*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
+ }
+ }
+ }
+}
+
+// -------------------------------------------------------------------------------------------------------------------------------------
+
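+// Per-type dispatch table: each specialization wires a quantization type to its tile loader, its MMA dot product, and its dp4a fallback.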
+template <int mmq_x, int mmq_y, bool need_check, ggml_type type>
+struct mmq_type_traits;
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> {
+ static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_DS4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_1> {
+ static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_0> {
+ static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_1> {
+ static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
+ static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
+ static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
+#ifdef BLACKWELL_MMA_AVAILABLE
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
+#else
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+#endif // BLACKWELL_MMA_AVAILABLE
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
+ static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q3_K> {
+ static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_K> {
+ static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q5_K> {
+ static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q6_K> {
+ static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XXS> {
+ static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_XS> {
+ static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ2_S> {
+ static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_XXS> {
+ static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ3_S> {
+ static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ1_S> {
+ static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma<mmq_x, mmq_y>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_NL> {
+ static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_IQ4_XS> {
+ static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
+template <ggml_type type, int mmq_x, bool need_check, bool fixup>
+static __device__ __forceinline__ void mul_mat_q_process_tile(
+ const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
+ const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+ const int stride_row_x, const int ncols_y, const int stride_col_dst,
+ const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
+
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
+ constexpr int mmq_y = get_mmq_y_device();
+ constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, need_check, type>::load_tiles;
+
+ extern __shared__ int data_mul_mat_q[];
+ int * tile_y = data_mul_mat_q + mmq_x;
+ int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_mma;
+ constexpr mmq_write_back_t write_back = mmq_write_back_mma<type, mmq_x, mmq_y, need_check>;
+#else
+ constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
+ constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+#if defined(BLACKWELL_MMA_AVAILABLE)
+ // FP4 tile stores 8 blocks
+ constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
+#else
+ constexpr int ne_block = 4 * QK8_1;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+
+ constexpr int ITER_K = get_iter_k(type);
+ constexpr int blocks_per_iter = ITER_K / qk;
+
+ float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
+
+ constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
+
+ for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
+ load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
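+        // y is processed in two halves per iteration, reusing the same shared-memory tile; vec_dot is called once per half.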
+ {
+ const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;
+#pragma unroll
+ for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
+ int l = l0 + threadIdx.y*warp_size + threadIdx.x;
+
+ tile_y[l] = by0[l];
+ }
+ }
+
+ __syncthreads();
+
+ vec_dot(tile_x, tile_y, sum, 0);
+
+ __syncthreads();
+
+ {
+ const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);
+#pragma unroll
+ for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
+ int l = l0 + threadIdx.y*warp_size + threadIdx.x;
+
+ tile_y[l] = by0[l];
+ }
+ }
+
+ __syncthreads();
+
+ vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K);
+
+ __syncthreads();
+ }
+
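+    // With stream-k a block may cover only part of a tile's K range; such partial sums are staged in tmp_fixup and combined later instead of being written to dst.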
+ if (fixup) {
+ write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
+ } else {
+ write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j);
+ }
+}
+
+
+// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
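+// The total work (output tiles times K iterations) is divided evenly over a fixed grid, so a thread block may start or stop in the middle of a tile;
+// partial tiles are written to tmp_fixup and completed by a separate fixup pass.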
+
+template <ggml_type type, int mmq_x, bool need_check>
+#if defined(GGML_USE_HIP)
+#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
+#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+#else
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1)
+#else
+ __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2)
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
+#endif // defined(GGML_USE_HIP)
+static __global__ void mul_mat_q(
+ const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
+ const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+ const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
+ const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+ const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+ const int ncols_max) {
+
+ // Skip unused template specializations for faster compilation:
+ if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
+ constexpr int mmq_y = get_mmq_y_device();
+
+ const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
+ const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
+
+ // Initialize the ids for writing back data with just the index.
+ // For regular matrix multiplications this is never changed.
+ // For MoE the correct indices are loaded from ids_dst.
+ extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
+
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
+ break;
+ }
+
+ ids_dst_shared[j] = j;
+ }
+ __syncthreads();
+
+ // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
+#if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+ {
+ const int wt = blockIdx.z / nchannels_y;
+ const int zt = blockIdx.z - wt*nchannels_y;
+ const int jt = blockIdx.y;
+ const int it = blockIdx.x;
+
+ // Defaults for regular matrix multiplication:
+ int col_low = 0;
+ int col_high = ncols_dst;
+ int col_diff = ncols_dst;
+ int offset_y = wt*stride_sample_y + zt*stride_channel_y;
+ int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
+
+ if (ids_dst) {
+ col_low = expert_bounds[zt + 0];
+ col_high = expert_bounds[zt + 1];
+ col_diff = col_high - col_low;
+
+ offset_y = 0;
+ offset_dst = 0;
+
+ if (jt*mmq_x >= col_diff) {
+ return;
+ }
+
+ // __syncthreads(); // There is no previous tile that could cause a race condition.
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
+
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
+ break;
+ }
+
+ ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
+ }
+ __syncthreads();
+ }
+
+ offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+ offset_dst += it*mmq_y;
+
+ const int tile_x_max_i = nrows_x - it*mmq_y - 1;
+ const int tile_y_max_j = col_diff - jt*mmq_x - 1;
+
+ const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
+
+ constexpr bool fixup = false;
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
+ (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
+ tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
+ return;
+ }
+#endif // (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
+
+ constexpr int ITER_K = get_iter_k(type);
+
+ const int64_t blocks_per_ne00 = ncols_x / qk;
+ constexpr int blocks_per_iter = ITER_K / qk;
+
+ // kbc == k block continuous, current index in continuous ijk space.
+ int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
+ int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
+
+ kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
+ kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
+
+ // kb0 == k index when doing the matrix multiplication for an output tile.
+ int kb0_start = kbc % blocks_per_ne00;
+ int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
+ while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
+ int tmp = kbc;
+ const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
+ tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
+ const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
+ tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
+ const int zt = tmp / (ntx*blocks_per_ne00);
+ tmp -= zt * (ntx*blocks_per_ne00);
+ const int jt = tmp / blocks_per_ne00;
+
+ // Defaults for regular matrix multiplication:
+ int col_low = 0;
+ int col_high = ncols_dst;
+ int col_diff = ncols_dst;
+ int offset_y = wt*stride_sample_y + zt*stride_channel_y;
+ int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
+
+ if (ids_dst) {
+ col_low = expert_bounds[zt + 0];
+ col_high = expert_bounds[zt + 1];
+ col_diff = col_high - col_low;
+
+ offset_y = 0;
+ offset_dst = 0;
+
+ if (jt*mmq_x >= col_diff) {
+ kbc += blocks_per_ne00;
+ kbc -= kbc % blocks_per_ne00;
+
+ kb0_start = 0;
+ kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
+
+ continue;
+ }
+
+ __syncthreads();
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
+
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
+ break;
+ }
+
+ ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
+ }
+ __syncthreads();
+ }
+
+ offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
+ offset_dst += it*mmq_y;
+
+ const int tile_x_max_i = nrows_x - it*mmq_y - 1;
+ const int tile_y_max_j = col_diff - jt*mmq_x - 1;
+
+ const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
+
+ constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
+ (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
+ tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
+
+ kbc += blocks_per_ne00;
+ kbc -= kbc % blocks_per_ne00;
+
+ kb0_start = 0;
+ kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
+ }
+
+ if (kbc >= kbc_stop) {
+ return;
+ }
+
+ int tmp = kbc;
+ const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
+ tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
+ const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
+ tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
+ const int zt = tmp / (ntx*blocks_per_ne00);
+ tmp -= zt * (ntx*blocks_per_ne00);
+ const int jt = tmp / blocks_per_ne00;
+
+ // Defaults for regular matrix multiplication:
+ int col_low = 0;
+ int col_high = ncols_dst;
+ int col_diff = ncols_dst;
+ int offset_y = wt*stride_sample_y + zt*stride_channel_y;
+ int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
+
+ if (ids_dst) {
+ col_low = expert_bounds[zt + 0];
+ col_high = expert_bounds[zt + 1];
+ col_diff = col_high - col_low;
+
+ offset_y = 0;
+ offset_dst = 0;
+
+ if (jt*mmq_x >= col_diff) {
+ return;
+ }
+
+ // The memory layout for the fixup buffer is always contiguous, therefore reset ids:
+ __syncthreads();
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) {
+ const int j = j0 + threadIdx.y*warp_size + threadIdx.x;
+
+ if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) {
+ break;
+ }
+
+ ids_dst_shared[j] = j;
+ }
+ __syncthreads();
+ }
+
+ offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
+ offset_dst += it*mmq_y;
+
+ const int tile_x_max_i = nrows_x - it*mmq_y - 1;
+ const int tile_y_max_j = col_diff - jt*mmq_x - 1;
+
+ const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
+
+ constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
+ mul_mat_q_process_tile<type, mmq_x, need_check, fixup>
+ (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst,
+ tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
+}
+
+template <ggml_type type, int mmq_x, bool need_check>
+static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
+ const int32_t * expert_bounds,
+ float * __restrict__ dst,
+ const float * __restrict__ tmp_last_tile,
+ const int ncols_x,
+ const int nrows_x,
+ const int ncols_dst,
+ const size_t stride_col_dst,
+ const int nchannels_y,
+ const size_t stride_channel_dst,
+ const int nsamples_y,
+ const size_t stride_sample_dst,
+ const int ncols_max) {
+ constexpr int mmq_y = get_mmq_y_device();
+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
+ constexpr int ITER_K = get_iter_k(type);
+
+ constexpr int blocks_per_iter = ITER_K / qk;
+ const int64_t blocks_per_ne00 = ncols_x / qk;
+
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
+
+ const int ntx = (ncols_max + mmq_x - 1) / mmq_x;
+ const int nty = (nrows_x + mmq_y - 1) / mmq_y;
+
+ const int bidx0 = blockIdx.x;
+
+ // kbc == k block continuous, current index in continuous ijk space.
+ int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
+ int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
+
+ kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter;
+ kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter;
+
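+ // Only the block that finished a tile started by earlier blocks needs to apply a fixup:
+ // blocks without any work, blocks whose range starts exactly at a tile boundary, and blocks that
+ // never reached the end of their first tile have nothing to combine and can return immediately.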
+ const bool did_not_have_any_data = kbc0 == kbc0_stop;
+ const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0;
+ const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0;
+ if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+ return;
+ }
+
+ bool any_fixup = false;
+
+ // Iterate over previous blocks and sum up partial sums written to fixup buffer.
+ // All CUDA blocks that get here must have a previous block that needs a fixup.
+ int64_t bidx = bidx0 - 1;
+ int64_t kbc_stop = kbc0;
+ while (true) {
+ int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
+ kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
+
+ if (kbc == kbc_stop) { // Did not have any data.
+ bidx--;
+ kbc_stop = kbc;
+ continue;
+ }
+
+ any_fixup = true;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+
+ sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
+ }
+ }
+
+ // If this block started in a previous tile we are done and don't need to combine additional partial results.
+ if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) {
+ break;
+ }
+ bidx--;
+ kbc_stop = kbc;
+ }
+
+ if (!any_fixup) {
+ return;
+ }
+
+ int tmp = kbc0;
+ const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
+ tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
+ const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
+ tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
+ const int zt = tmp / (ntx*blocks_per_ne00);
+ tmp -= zt * (ntx*blocks_per_ne00);
+ const int jt = tmp / blocks_per_ne00;
+
+ if (!ids_dst) {
+ const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
+ dst += offset_dst;
+
+ const int i_max = nrows_x - it*mmq_y - 1;
+ const int j_max = ncols_dst - jt*mmq_x - 1;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (j > j_max) {
+ return;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+
+ if (need_check && i > i_max) {
+ continue;
+ }
+
+ dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
+ }
+ }
+ return;
+ }
+
+ __shared__ int ids_dst_shared[mmq_x];
+ const int col_low = expert_bounds[zt + 0];
+ const int col_high = expert_bounds[zt + 1];
+ const int col_diff = col_high - col_low;
+
+ for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) {
+ ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
+ }
+ __syncthreads();
+
+ const int offset_dst = it*mmq_y;
+ dst += offset_dst;
+
+ const int i_max = nrows_x - it*mmq_y - 1;
+ const int j_max = col_diff - jt*mmq_x - 1;
+
+#pragma unroll
+ for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (j > j_max) {
+ return;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += warp_size) {
+ const int i = i0 + threadIdx.x;
+
+ if (need_check && i > i_max) {
+ continue;
+ }
+
+ dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size];
+ }
+ }
+}
+
+struct mmq_args {
+ const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;
+ int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
+ int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
+ int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
+ bool use_stream_k; int64_t ncols_max;
+};
+
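+// Shared memory per block: the dst indices (mmq_x ints), the x tile (layout depends on whether the
+// mma or the dp4a path is used), and the quantized y tile padded to a multiple of nwarps*warp_size ints.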
+template<ggml_type type>
+static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) {
+ const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
+ const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
+ const size_t nbs_ids = mmq_x*sizeof(int);
+ const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
+ const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq));
+ return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
+}
+
+template <ggml_type type, int mmq_x>
+static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
+ const int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
+ const int nsm = ggml_cuda_info().devices[id].nsm;
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
+ const int nwarps = mmq_get_nwarps_host(cc, warp_size);
+ const int mmq_y = get_mmq_y_host(cc);
+
+ const dim3 block_dims(warp_size, nwarps, 1);
+
+ const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps);
+
+ CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, false>), nbytes_shared);
+ CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x, true>), nbytes_shared);
+
+ const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
+ const int ntx = (args.ncols_max + mmq_x - 1) / mmq_x;
+ const int ntzw = args.nchannels_y * args.nsamples_y;
+ const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
+
+ GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0);
+ GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0);
+ const int channel_ratio = args.nchannels_y / args.nchannels_x;
+ const int sample_ratio = args.nsamples_y / args.nsamples_x;
+
+ if (!args.use_stream_k) {
+ if (args.nrows_x % mmq_y == 0) {
+ constexpr bool need_check = false;
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
+ (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
+ args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
+ channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+ args.ncols_max);
+ } else {
+ constexpr bool need_check = true;
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
+ (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
+ args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
+ channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+ args.ncols_max);
+ }
+ return;
+ }
+
+ const dim3 block_nums_stream_k(nsm, 1, 1);
+ const bool fixup_needed = ntx*nty*ntzw % nsm != 0;
+
+ ggml_cuda_pool & pool = ctx.pool(id);
+ ggml_cuda_pool_alloc<float> tmp_fixup(pool);
+ if (fixup_needed) {
+ tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
+ }
+
+ if (args.nrows_x % mmq_y == 0) {
+ constexpr bool need_check = false;
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+ (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
+ args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
+ channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+ args.ncols_max);
+
+ if (!fixup_needed) {
+ return;
+ }
+
+ mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
+ (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
+ args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+ args.ncols_max);
+ } else {
+ constexpr bool need_check = true;
+ mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
+ (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
+ args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
+ channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
+ sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+ args.ncols_max);
+
+ if (!fixup_needed) {
+ return;
+ }
+
+ mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
+ (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
+ args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+ args.ncols_max);
+ }
+}
+
+template <ggml_type type>
+void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
+ const int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+ const int warp_size = ggml_cuda_info().devices[id].warp_size;
+ const int nwarps = mmq_get_nwarps_host(cc, warp_size);
+
+ const int mmq_x_max = get_mmq_x_max_host(cc);
+ const int mmq_y = get_mmq_y_host(cc);
+
+ int mmq_x_best = 0;
+ int ntiles_x_best = INT_MAX;
+
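+ // Tile width selection: try candidates in steps of 8, skip those that violate the granularity
+ // requirement or exceed the shared memory opt-in limit, and keep the smallest mmq_x that minimizes
+ // the number of tiles along ncols_max; the loop stops early once a single tile suffices.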
+ for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
+ const int granularity = mmq_get_granularity_host(mmq_x, cc);
+
+ if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) {
+ continue;
+ }
+
+ const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
+
+ if (ntiles_x < ntiles_x_best) {
+ mmq_x_best = mmq_x;
+ ntiles_x_best = ntiles_x;
+ }
+ }
+
+ switch (mmq_x_best) {
+ case 8:
+ launch_mul_mat_q<type, 8>(ctx, args, stream);
+ break;
+ case 16:
+ launch_mul_mat_q<type, 16>(ctx, args, stream);
+ break;
+ case 24:
+ launch_mul_mat_q<type, 24>(ctx, args, stream);
+ break;
+ case 32:
+ launch_mul_mat_q<type, 32>(ctx, args, stream);
+ break;
+ case 40:
+ launch_mul_mat_q<type, 40>(ctx, args, stream);
+ break;
+ case 48:
+ launch_mul_mat_q<type, 48>(ctx, args, stream);
+ break;
+ case 56:
+ launch_mul_mat_q<type, 56>(ctx, args, stream);
+ break;
+ case 64:
+ launch_mul_mat_q<type, 64>(ctx, args, stream);
+ break;
+ case 72:
+ launch_mul_mat_q<type, 72>(ctx, args, stream);
+ break;
+ case 80:
+ launch_mul_mat_q<type, 80>(ctx, args, stream);
+ break;
+ case 88:
+ launch_mul_mat_q<type, 88>(ctx, args, stream);
+ break;
+ case 96:
+ launch_mul_mat_q<type, 96>(ctx, args, stream);
+ break;
+ case 104:
+ launch_mul_mat_q<type, 104>(ctx, args, stream);
+ break;
+ case 112:
+ launch_mul_mat_q<type, 112>(ctx, args, stream);
+ break;
+ case 120:
+ launch_mul_mat_q<type, 120>(ctx, args, stream);
+ break;
+ case 128:
+ launch_mul_mat_q<type, 128>(ctx, args, stream);
+ break;
+ default:
+ fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
+ GGML_ABORT("fatal error");
+ break;
+ }
+}
+
+#define DECL_MMQ_CASE(type) \
+ template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
+
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
+extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
+extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
+extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
+extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
+
+// -------------------------------------------------------------------------------------------------------------------------
+
+void ggml_cuda_mul_mat_q(
+ ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
+void ggml_cuda_op_mul_mat_q(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts);
diff --git a/llama.cpp/ggml/src/ggml-cuda/mmvf.cu b/llama.cpp/ggml/src/ggml-cuda/mmvf.cu
new file mode 100644
index 0000000..d914720
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/mmvf.cu
@@ -0,0 +1,862 @@
+#include "ggml.h"
+#include "common.cuh"
+#include "unary.cuh"
+#include "mmvf.cuh"
+#include "convert.cuh"
+
+template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false, bool is_multi_token_id = false>
+static __global__ void mul_mat_vec_f(
+ const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
+ const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
+ const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+ const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+ const int ids_stride) {
+ const int row = blockIdx.x;
+ // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens)
+ const int channel_dst = blockIdx.y;
+ const int tid = threadIdx.x;
+
+ int token_idx;
+ int channel_x;
+ int channel_y;
+ int sample_dst;
+
+ if constexpr (is_multi_token_id) {
+ // Multi-token MUL_MAT_ID path; adding these in the normal path causes a perf regression for the n_tokens=1 case
+ token_idx = blockIdx.z;
+ channel_x = ids[channel_dst + token_idx * ids_stride];
+ channel_y = fastmodulo(channel_dst, nchannels_y);
+ sample_dst = 0;
+ } else {
+ token_idx = ids ? blockIdx.z : 0;
+ channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv((uint32_t) channel_dst, channel_ratio);
+ channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst;
+ sample_dst = ids ? 0 : blockIdx.z;
+ }
+
+ const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
+ const int sample_y = sample_dst;
+
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
+ y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
+ dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
+ if constexpr (is_multi_token_id) {
+ y += token_idx*stride_col_y2*2;
+ dst += token_idx*stride_col_dst;
+ }
+
+ bool use_gate = false;
+ bool use_bias = false;
+ bool use_gate_bias = false;
+ ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;
+ const T * gate_x = nullptr;
+ const float * x_bias = nullptr;
+ const float * gate_bias = nullptr;
+
+ if constexpr (has_fusion) {
+ use_gate = fusion.gate != nullptr;
+ use_bias = fusion.x_bias != nullptr;
+ use_gate_bias = fusion.gate_bias != nullptr;
+ glu_op = fusion.glu_op;
+
+ if (use_gate) {
+ gate_x = static_cast<const T *>(fusion.gate);
+ }
+ if (use_bias) {
+ x_bias = static_cast<const float *>(fusion.x_bias);
+ }
+ if (use_gate_bias) {
+ gate_bias = static_cast<const float *>(fusion.gate_bias);
+ use_gate_bias = use_gate;
+ } else {
+ use_gate_bias = false;
+ }
+ }
+
+ if (use_gate) {
+ gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
+ }
+
+ const int channel_bias = ids ? channel_x : channel_dst;
+
+ if constexpr (has_fusion) {
+ if (use_bias) {
+ x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
+ }
+ if (use_gate_bias) {
+ gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
+ }
+ }
+
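+ // Columns are processed in pairs via 2-element vector loads (float2/half2/nv_bfloat162);
+ // the launcher asserts ncols % 2 == 0 and passes ncols/2 and stride_col_y/2 as ncols2/stride_col_y2.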
+ const float2 * y2 = (const float2 *) y;
+
+ extern __shared__ char data_mmv[];
+ float * buf_iw = (float *) data_mmv;
+ float * buf_iw_gate = nullptr;
+ if constexpr (has_fusion) {
+ buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));
+ }
+
+ if (block_size > warp_size) {
+ if (tid < warp_size) {
+ buf_iw[tid] = 0.0f;
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ buf_iw_gate[tid] = 0.0f;
+ }
+ }
+ }
+ __syncthreads();
+ }
+
+ float sumf[ncols_dst] = {0.0f};
+ float sumf_gate[ncols_dst];
+ if constexpr (has_fusion) {
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+ sumf_gate[j] = 0.0f;
+ }
+ }
+
+ if constexpr (std::is_same_v<T, float>) {
+ const float2 * x2 = (const float2 *) x;
+ const float2 * gate_x2 = nullptr;
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ gate_x2 = (const float2 *) gate_x;
+ }
+ }
+
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+ const float2 tmpx = x2[col2];
+ float2 tmpx_gate = make_float2(0.0f, 0.0f);
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ tmpx_gate = gate_x2[col2];
+ }
+ }
+
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
+ ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+ ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
+
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
+ ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
+ }
+ }
+ }
+ }
+ } else if constexpr (std::is_same_v<T, half>) {
+ const half2 * x2 = (const half2 *) x;
+ const half2 * gate_x2 = nullptr;
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ gate_x2 = (const half2 *) gate_x;
+ }
+ }
+
+ if (std::is_same_v<type_acc, float>) {
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+ const float2 tmpx = __half22float2(x2[col2]);
+ float2 tmpx_gate = make_float2(0.0f, 0.0f);
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ tmpx_gate = __half22float2(gate_x2[col2]);
+ }
+ }
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
+ ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+ ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
+
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
+ ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
+ }
+ }
+ }
+ }
+ } else {
+#ifdef FP16_AVAILABLE
+ half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
+ half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};
+
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+ const half2 tmpx = x2[col2];
+ half2 tmpx_gate = make_half2(0.0f, 0.0f);
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ tmpx_gate = gate_x2[col2];
+ }
+ }
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
+ sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
+
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y);
+ }
+ }
+ }
+ }
+
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+ sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
+ }
+
+ if constexpr (has_fusion) {
+ if (use_gate) {
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+ sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]);
+ }
+ }
+ }
+#else
+ NO_DEVICE_CODE;
+#endif // FP16_AVAILABLE
+ }
+ } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
+// TODO: add support for ggml_cuda_mad for hip_bfloat162
+#if defined(GGML_USE_HIP)
+ const int * x2 = (const int *) x;
+ const int * gate_x2 = nullptr;
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ gate_x2 = (const int *) gate_x;
+ }
+ }
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+ const int tmpx = x2[col2];
+ int tmpx_gate = 0;
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ tmpx_gate = gate_x2[col2];
+ }
+ }
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
+ const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
+ const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
+ ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
+ ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
+
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]);
+ const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]);
+ ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x);
+ ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y);
+ }
+ }
+ }
+ }
+#else
+ const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
+ const nv_bfloat162 * gate_x2 = nullptr;
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ gate_x2 = (const nv_bfloat162 *) gate_x;
+ }
+ }
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+ const nv_bfloat162 tmpx = x2[col2];
+ nv_bfloat162 tmpx_gate;
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ tmpx_gate = gate_x2[col2];
+ }
+ }
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
+ ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+ ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
+
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
+ ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
+ }
+ }
+ }
+ }
+#endif // defined(GGML_USE_HIP)
+ } else {
+ static_assert(std::is_same_v<T, void>, "unsupported type");
+ }
+
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+ sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
+
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
+ }
+ }
+
+ if (block_size > warp_size) {
+ buf_iw[tid/warp_size] = sumf[j];
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ buf_iw_gate[tid/warp_size] = sumf_gate[j];
+ }
+ }
+ __syncthreads();
+ if (tid < warp_size) {
+ sumf[j] = buf_iw[tid];
+ sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ sumf_gate[j] = buf_iw_gate[tid];
+ sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
+ }
+ }
+ }
+
+ if (j < ncols_dst) {
+ __syncthreads();
+ }
+ }
+ }
+
+ if (tid >= ncols_dst) {
+ return;
+ }
+
+ float value = sumf[tid];
+
+ if constexpr (has_fusion) {
+ if (use_bias) {
+ value += x_bias[tid*stride_col_dst + row];
+ }
+
+ if (use_gate) {
+ float gate_value = sumf_gate[tid];
+ if (use_gate_bias) {
+ gate_value += gate_bias[tid*stride_col_dst + row];
+ }
+ switch (glu_op) {
+ case GGML_GLU_OP_SWIGLU:
+ value *= ggml_cuda_op_silu_single(gate_value);
+ break;
+ case GGML_GLU_OP_GEGLU:
+ value *= ggml_cuda_op_gelu_single(gate_value);
+ break;
+ case GGML_GLU_OP_SWIGLU_OAI: {
+ value = ggml_cuda_op_swiglu_oai_single(gate_value, value);
+ break;
+ }
+ default:
+ break;
+ }
+ }
+ }
+
+ dst[tid*stride_col_dst + row] = value;
+
+ if constexpr (!has_fusion) {
+ GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate);
+ }
+}
+
+template<typename T, typename type_acc, int ncols_dst, int block_size, bool is_multi_token_id = false>
+static void mul_mat_vec_f_switch_fusion(
+ const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+ const int64_t ncols, const uint3 nchannels_y,
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+ const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+ const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) {
+
+ const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+ if constexpr (ncols_dst == 1) {
+ if (has_fusion) {
+ mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
+ (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
+ return;
+ }
+ }
+
+ GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
+
+ mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
+ (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
+
+}
+
+template <typename T, typename type_acc, int ncols_dst, bool is_multi_token_id = false>
+void launch_mul_mat_vec_f_cuda(
+ const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+ const int64_t ncols, const int64_t nrows,
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ const int64_t nsamples_or_ntokens, const int64_t ids_stride, cudaStream_t stream) {
+ GGML_ASSERT(ncols % 2 == 0);
+ GGML_ASSERT(stride_row % 2 == 0);
+ GGML_ASSERT(stride_col_y % 2 == 0);
+ GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+ GGML_ASSERT( nsamples_dst % nsamples_x == 0);
+ const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
+ const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
+ const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
+
+ const int device = ggml_cuda_get_device();
+ const int warp_size = ggml_cuda_info().devices[device].warp_size;
+
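+ // Pick the smallest block size (a multiple of warp_size, capped at max_block_size) that minimizes
+ // the number of iterations each thread makes over the column pairs.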
+ int64_t block_size_best = warp_size;
+ int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
+ int64_t max_block_size = 256;
+ if (ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
+ max_block_size = 128;
+ }
+ for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
+ const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
+ if (niter < niter_best) {
+ niter_best = niter;
+ block_size_best = block_size;
+ }
+ }
+
+ const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+
+ const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
+ const dim3 block_nums(nrows, nchannels_dst, nsamples_or_ntokens);
+ const dim3 block_dims(block_size_best, 1, 1);
+ switch (block_size_best) {
+ case 32: {
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32, is_multi_token_id>
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
+ } break;
+ case 64: {
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64, is_multi_token_id>
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
+ } break;
+ case 96: {
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96, is_multi_token_id>
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
+ } break;
+ case 128: {
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128, is_multi_token_id>
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
+ } break;
+ case 160: {
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160, is_multi_token_id>
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
+ } break;
+ case 192: {
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192, is_multi_token_id>
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
+ } break;
+ case 224: {
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224, is_multi_token_id>
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
+ } break;
+ case 256: {
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256, is_multi_token_id>
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
+ } break;
+ default: {
+ GGML_ABORT("fatal error");
+ } break;
+ }
+}
+
+template <typename T, typename type_acc>
+static void mul_mat_vec_f_cuda_switch_ncols_dst(
+ const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+ const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ const int64_t ids_stride, cudaStream_t stream) {
+
+ const bool has_ids = ids != nullptr;
+
+ if (has_ids && ncols_dst > 1) {
+ // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
+ constexpr int c_ncols_dst = 1;
+ launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst, true>
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ ncols_dst, ids_stride, stream);
+ return;
+ }
+
+ if (has_ids) {
+ // Single-token MUL_MAT_ID path
+ constexpr int c_ncols_dst = 1;
+ launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst>
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ ncols_dst, ids_stride, stream);
+ return;
+ }
+
+ switch (ncols_dst) {
+ case 1:
+ launch_mul_mat_vec_f_cuda<T, type_acc, 1>
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ nsamples_dst, ids_stride, stream);
+ break;
+ case 2:
+ launch_mul_mat_vec_f_cuda<T, type_acc, 2>
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ nsamples_dst, ids_stride, stream);
+ break;
+ case 3:
+ launch_mul_mat_vec_f_cuda<T, type_acc, 3>
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ nsamples_dst, ids_stride, stream);
+ break;
+ case 4:
+ launch_mul_mat_vec_f_cuda<T, type_acc, 4>
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ nsamples_dst, ids_stride, stream);
+ break;
+ case 5:
+ launch_mul_mat_vec_f_cuda<T, type_acc, 5>
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ nsamples_dst, ids_stride, stream);
+ break;
+ case 6:
+ launch_mul_mat_vec_f_cuda<T, type_acc, 6>
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ nsamples_dst, ids_stride, stream);
+ break;
+ case 7:
+ launch_mul_mat_vec_f_cuda<T, type_acc, 7>
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ nsamples_dst, ids_stride, stream);
+ break;
+ case 8:
+ launch_mul_mat_vec_f_cuda<T, type_acc, 8>
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ nsamples_dst, ids_stride, stream);
+ break;
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+}
+
+template<typename T>
+static void mul_mat_vec_f_cuda(
+ const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+ const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+ const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ const int64_t ids_stride, enum ggml_prec prec, cudaStream_t stream) {
+
+ if constexpr(std::is_same_v<T, half>) {
+ if (prec == GGML_PREC_DEFAULT) {
+ mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
+ (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ return;
+ }
+ }
+ mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
+ (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+}
+
+void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
+ const ggml_cuda_mm_fusion_args_host * fusion) {
+ GGML_ASSERT( src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ const size_t ts_src0 = ggml_type_size(src0->type);
+ const size_t ts_src1 = ggml_type_size(src1->type);
+ const size_t ts_dst = ggml_type_size(dst->type);
+
+ GGML_ASSERT(!ids || ne12 <= MMVF_MAX_BATCH_SIZE);
+ GGML_ASSERT(ne13 == ne3);
+
+ GGML_ASSERT( nb00 == ts_src0);
+ GGML_ASSERT( nb10 == ts_src1);
+ GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+ GGML_ASSERT( nb0 == ts_dst);
+
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+ const float * src1_d = (const float *) src1->data;
+ const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
+ float * dst_d = (float *) dst->data;
+
+ ggml_cuda_mm_fusion_args_device fusion_local{};
+
+ if (fusion) {
+ GGML_ASSERT( !ids || dst->ne[2] == 1);
+ GGML_ASSERT( ids || dst->ne[1] == 1);
+ if (fusion->x_bias) {
+ GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
+ GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
+ GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
+ fusion_local.x_bias = fusion->x_bias->data;
+ }
+ if (fusion->gate) {
+ GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
+ fusion_local.gate = fusion->gate->data;
+ }
+ if (fusion->gate_bias) {
+ GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
+ GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
+ GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
+ fusion_local.gate_bias = fusion->gate_bias->data;
+ }
+ fusion_local.glu_op = fusion->glu_op;
+ }
+
+ const int64_t s01 = src0->nb[1] / ts_src0;
+ const int64_t s11 = src1->nb[1] / ts_src1;
+ const int64_t s1 = dst->nb[1] / ts_dst;
+ const int64_t s02 = src0->nb[2] / ts_src0;
+ const int64_t s12 = src1->nb[2] / ts_src1;
+ const int64_t s2 = dst->nb[2] / ts_dst;
+ const int64_t s03 = src0->nb[3] / ts_src0;
+ const int64_t s13 = src1->nb[3] / ts_src1;
+ const int64_t s3 = dst->nb[3] / ts_dst;
+
+ // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
+ const int64_t ncols_dst = ids ? ne2 : ne1;
+ const int64_t nchannels_y = ids ? ne11 : ne12;
+ const int64_t nchannels_dst = ids ? ne1 : ne2;
+ const int64_t stride_col_dst = ids ? s2 : s1;
+ const int64_t stride_col_y = ids ? s12 : s11;
+ const int64_t stride_channel_dst = ids ? s1 : s2;
+ const int64_t stride_channel_y = ids ? s11 : s12;
+
+ const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
+
+ switch (src0->type) {
+ case GGML_TYPE_F32: {
+ const float * src0_d = (const float *) src0->data;
+ mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
+ ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream());
+ } break;
+ case GGML_TYPE_F16: {
+ const half * src0_d = (const half *) src0->data;
+ mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
+ ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream());
+ } break;
+ case GGML_TYPE_BF16: {
+ const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
+ mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
+ ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream());
+ } break;
+ default:
+ GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+ }
+}
+
+void ggml_cuda_op_mul_mat_vec_f(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne10 = src1->ne[0];
+ const int64_t ne0 = dst->ne[0];
+ const int64_t row_diff = row_high - row_low;
+
+ const int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
+ const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
+
+ // ggml_cuda_op provides single, contiguous matrices
+ const int64_t stride_row = ne00;
+ const int64_t stride_col_y = ne10;
+ const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
+ const int64_t nchannels_x = 1;
+ const int64_t nchannels_y = 1;
+ const int64_t nchannels_dst = 1;
+ const int64_t stride_channel_x = 0;
+ const int64_t stride_channel_y = 0;
+ const int64_t stride_channel_dst = 0;
+ const int64_t nsamples_x = 1;
+ const int64_t nsamples_dst = 1;
+ const int64_t stride_sample_x = 0;
+ const int64_t stride_sample_y = 0;
+ const int64_t stride_sample_dst = 0;
+
+ ggml_cuda_mm_fusion_args_device empty{};
+ switch (src0->type) {
+ case GGML_TYPE_F32: {
+ const float * src0_d = (const float *) src0_dd_i;
+ mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
+ } break;
+ case GGML_TYPE_F16: {
+ const half * src0_d = (const half *) src0_dd_i;
+ mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
+ } break;
+ case GGML_TYPE_BF16: {
+ const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
+ mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
+ } break;
+ default:
+ GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+ }
+
+ GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_ncols, src1_padded_row_size);
+}
+
+bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11) {
+ if (src0_ne[0] % 2 != 0) {
+ return false;
+ }
+
+ const size_t ts = ggml_type_size(type);
+ if (src0_nb[0] != ts) {
+ return false;
+ }
+
+ // Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
+ for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+ if (src0_nb[i] % (2*ts) != 0) {
+ return false;
+ }
+ }
+
+ switch (type) {
+ case GGML_TYPE_F32:
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+ if (ampere_mma_available(cc)) {
+ return ne11 <= 3;
+ }
+ if (cc >= GGML_CUDA_CC_TURING) {
+ return ne11 <= 4;
+ }
+ return ne11 <= 3;
+ } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+ if (fp32_mma_hardware_available(cc)) {
+ return ne11 <= 3;
+ }
+ return ne11 <= 8;
+ }
+ return ne11 <= 8;
+ case GGML_TYPE_F16:
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+ const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+ if (ampere_mma_available(cc)) {
+ return src0_small && ne11 == 1;
+ }
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+ return src0_small && ne11 <= 4;
+ }
+ if (fp16_mma_hardware_available(cc)) {
+ return src0_small && ne11 <= 3;
+ }
+ return ne11 <= 8;
+ } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+ if (fp16_mma_hardware_available(cc)) {
+ if (GGML_CUDA_CC_IS_RDNA3(cc)) {
+ return ne11 <= 3;
+ }
+ if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+ return ne11 <= 5;
+ }
+ return ne11 <= 2;
+ }
+ return ne11 <= 8;
+ }
+ return ne11 <= 8;
+ case GGML_TYPE_BF16:
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+ const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+ if (ampere_mma_available(cc)) {
+ return src0_small && ne11 == 1;
+ }
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+ return src0_small && ne11 <= 4;
+ }
+ if (bf16_mma_hardware_available(cc)) {
+ return src0_small && ne11 <= 3;
+ }
+ return ne11 <= 8;
+ } else if (GGML_CUDA_CC_IS_AMD(cc)) {
+ if (bf16_mma_hardware_available(cc)) {
+ return ne11 <= 3;
+ }
+ return ne11 <= 8;
+ }
+ return ne11 <= 8;
+ default:
+ return false;
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/mmvf.cuh b/llama.cpp/ggml/src/ggml-cuda/mmvf.cuh
new file mode 100644
index 0000000..a50f7c0
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/mmvf.cuh
@@ -0,0 +1,14 @@
+#include "common.cuh"
+
+#define MMVF_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVF kernels.
+
+void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
+ const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
+
+void ggml_cuda_op_mul_mat_vec_f(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11);
diff --git a/llama.cpp/ggml/src/ggml-cuda/mmvq.cu b/llama.cpp/ggml/src/ggml-cuda/mmvq.cu
new file mode 100644
index 0000000..ce25ccf
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/mmvq.cu
@@ -0,0 +1,767 @@
+#include "mmvq.cuh"
+#include "quantize.cuh"
+#include "unary.cuh"
+#include "vecdotq.cuh"
+
+#include <cstdint>
+
+typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
+
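+// Maps a quantized ggml type to its q8_1 dot-product kernel.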
+static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1;
+ case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1;
+ case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
+ case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
+ case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
+ case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
+ case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
+ case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
+ case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
+ case GGML_TYPE_Q5_K: return vec_dot_q5_K_q8_1;
+ case GGML_TYPE_Q6_K: return vec_dot_q6_K_q8_1;
+ case GGML_TYPE_IQ2_XXS: return vec_dot_iq2_xxs_q8_1;
+ case GGML_TYPE_IQ2_XS: return vec_dot_iq2_xs_q8_1;
+ case GGML_TYPE_IQ2_S: return vec_dot_iq2_s_q8_1;
+ case GGML_TYPE_IQ3_XXS: return vec_dot_iq3_xxs_q8_1;
+ case GGML_TYPE_IQ1_S: return vec_dot_iq1_s_q8_1;
+ case GGML_TYPE_IQ1_M: return vec_dot_iq1_m_q8_1;
+ case GGML_TYPE_IQ4_NL: return vec_dot_iq4_nl_q8_1;
+ case GGML_TYPE_IQ4_XS: return vec_dot_iq4_xs_q8_1;
+ case GGML_TYPE_IQ3_S: return vec_dot_iq3_s_q8_1;
+ default: return nullptr;
+ }
+}
+
+static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
+ case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
+ case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
+ case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
+ case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
+ case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
+ case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
+ case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
+ case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
+ case GGML_TYPE_Q5_K: return VDR_Q5_K_Q8_1_MMVQ;
+ case GGML_TYPE_Q6_K: return VDR_Q6_K_Q8_1_MMVQ;
+ case GGML_TYPE_IQ2_XXS: return VDR_IQ2_XXS_Q8_1_MMVQ;
+ case GGML_TYPE_IQ2_XS: return VDR_IQ2_XS_Q8_1_MMVQ;
+ case GGML_TYPE_IQ2_S: return VDR_IQ2_S_Q8_1_MMVQ;
+ case GGML_TYPE_IQ3_XXS: return VDR_IQ3_XXS_Q8_1_MMVQ;
+ case GGML_TYPE_IQ3_S: return VDR_IQ3_S_Q8_1_MMVQ;
+ case GGML_TYPE_IQ4_NL: return VDR_IQ4_NL_Q8_1_MMVQ;
+ case GGML_TYPE_IQ4_XS: return VDR_IQ4_XS_Q8_1_MMVQ;
+ default: return 1;
+ }
+}
+
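+// Launch parameters (warps per block, rows per block) are chosen from per-architecture
+// tables: generic (NVIDIA and default), GCN/CDNA, and RDNA2/3/4.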
+enum mmvq_parameter_table_id {
+ MMVQ_PARAMETERS_GENERIC = 0,
+ MMVQ_PARAMETERS_GCN,
+ MMVQ_PARAMETERS_RDNA2
+};
+
+static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
+#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
+ return MMVQ_PARAMETERS_RDNA2;
+#elif defined(GCN) || defined(CDNA)
+ return MMVQ_PARAMETERS_GCN;
+#else
+ return MMVQ_PARAMETERS_GENERIC;
+#endif
+}
+
+static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
+ if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+ return MMVQ_PARAMETERS_RDNA2;
+ }
+ if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
+ return MMVQ_PARAMETERS_GCN;
+ }
+ return MMVQ_PARAMETERS_GENERIC;
+}
+
+static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
+ if (table_id == MMVQ_PARAMETERS_GENERIC) {
+ switch (ncols_dst) {
+ case 1:
+ case 2:
+ case 3:
+ case 4:
+ return 4;
+ case 5:
+ case 6:
+ case 7:
+ case 8:
+ return 2;
+ default:
+ return 1;
+ }
+ } else if (table_id == MMVQ_PARAMETERS_GCN) {
+ switch (ncols_dst) {
+ case 1:
+ case 2:
+ case 3:
+ case 4:
+ return 2;
+ case 5:
+ case 6:
+ case 7:
+ case 8:
+ default:
+ return 1;
+ }
+ }
+ return 1;
+}
+
+static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {
+ if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
+ switch (ncols_dst) {
+ case 1:
+ return 1;
+ case 2:
+ case 3:
+ case 4:
+ case 5:
+ case 6:
+ case 7:
+ case 8:
+ return 2;
+ default:
+ return 1;
+ }
+ }
+ return 1;
+}
+
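+// Each block computes rows_per_cuda_block rows of dst for up to ncols_dst columns: the
+// quantized x rows are dotted against the q8_1-quantized y columns, partial sums are
+// combined across warps through shared memory, and the result is optionally fused with
+// a bias add and/or a GLU gate before being written out.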
+template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
+__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
+static __global__ void mul_mat_vec_q(
+ const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
+ const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
+ const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
+ const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
+ const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
+ const uint32_t ids_stride) {
+
+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
+ constexpr int qi = ggml_cuda_type_traits<type>::qi;
+ constexpr int vdr = get_vdr_mmvq(type);
+ constexpr mmvq_parameter_table_id table_id = get_device_table_id();
+ constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
+ constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
+
+ const int tid = warp_size*threadIdx.y + threadIdx.x;
+ const int row0 = rows_per_cuda_block*blockIdx.x;
+ const int blocks_per_row_x = ncols_x / qk;
+ constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
+
+ const uint32_t channel_dst = blockIdx.y;
+
+ uint32_t token_idx = 0;
+ uint32_t channel_x;
+ uint32_t channel_y;
+ uint32_t sample_dst;
+
+ if constexpr (is_multi_token_id) {
+        // Multi-token MUL_MAT_ID path; computing these indices in the normal path below causes a perf regression for the n_tokens=1 case
+ token_idx = blockIdx.z;
+ channel_x = ids[channel_dst + token_idx * ids_stride];
+ channel_y = fastmodulo(channel_dst, nchannels_y);
+ sample_dst = 0;
+ } else {
+ channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
+ channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
+ sample_dst = blockIdx.z;
+ }
+
+ const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
+ const uint32_t sample_y = sample_dst;
+
+ bool use_gate = false;
+ bool use_bias = false;
+ bool use_gate_bias = false;
+ const void * vgate = nullptr;
+ const float * x_bias = nullptr;
+ const float * gate_bias = nullptr;
+ ggml_glu_op active_glu;
+
+ if constexpr (has_fusion) {
+ use_gate = fusion.gate != nullptr;
+ use_bias = fusion.x_bias != nullptr;
+ use_gate_bias = fusion.gate_bias != nullptr && use_gate;
+ vgate = fusion.gate;
+ x_bias = (const float *) fusion.x_bias;
+ gate_bias = (const float *) fusion.gate_bias;
+ active_glu = fusion.glu_op;
+ }
+
+ float x_biases[ncols_dst] = { 0.0f };
+ float gate_biases[ncols_dst] = { 0.0f };
+ if constexpr (has_fusion) {
+ const uint32_t channel_bias = ids ? channel_x : channel_dst;
+ if (use_bias) {
+ x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
+            // 1. Hide latency by prefetching the bias and gate values here
+            // 2. Load only on threads that still participate after the partial-sum reduction (the other warps return early)
+ if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
+ (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+ x_biases[j] = x_bias[j * stride_col_dst + threadIdx.x];
+ }
+ }
+ }
+ if (use_gate_bias) {
+ gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
+ if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
+ (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+ gate_biases[j] = gate_bias[j * stride_col_dst + threadIdx.x];
+ }
+ }
+ }
+ }
+
+ // partial sum for each thread
+ float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
+ float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
+
+ const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
+ if constexpr (is_multi_token_id) {
+ y += token_idx*stride_col_y;
+ }
+ const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
+
+ for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
+ const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
+
+ // x block quant index when casting the quants to int
+ const int kqs = vdr * (tid % (qi/vdr));
+
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+#pragma unroll
+ for (int i = 0; i < rows_per_cuda_block; ++i) {
+ tmp[j][i] += vec_dot_q_cuda(
+ vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ tmp_gate[j][i] += vec_dot_q_cuda(
+ vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
+ }
+ }
+ }
+ }
+ }
+
+ __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
+ __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
+ if constexpr (!has_fusion) {
+ (void) tmp_shared_gate;
+ } else if (!use_gate) {
+ (void) tmp_shared_gate;
+ }
+
+ if (threadIdx.y > 0) {
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+#pragma unroll
+ for (int i = 0; i < rows_per_cuda_block; ++i) {
+ tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i];
+ }
+ }
+ }
+ }
+ }
+ __syncthreads();
+ if (threadIdx.y > 0) {
+ return;
+ }
+
+ dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
+
+ if constexpr (is_multi_token_id) {
+ dst += token_idx*stride_col_dst;
+ }
+
+ // sum up partial sums and write back result
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+#pragma unroll
+ for (int i = 0; i < rows_per_cuda_block; ++i) {
+#pragma unroll
+ for (int l = 0; l < nwarps-1; ++l) {
+ tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x];
+ }
+ }
+ }
+ tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
+ if constexpr (has_fusion) {
+ if (use_gate) {
+ tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]);
+ }
+ }
+ }
+
+ if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
+ float result = tmp[j][threadIdx.x];
+ if constexpr (has_fusion) {
+ if (use_bias) {
+ result += x_biases[j];
+ }
+ if (use_gate) {
+ float gate_value = tmp_gate[j][threadIdx.x];
+ if (use_gate_bias) {
+ gate_value += gate_biases[j];
+ }
+ switch (active_glu) {
+ case GGML_GLU_OP_SWIGLU:
+ result *= ggml_cuda_op_silu_single(gate_value);
+ break;
+ case GGML_GLU_OP_GEGLU:
+ result *= ggml_cuda_op_gelu_single(gate_value);
+ break;
+ case GGML_GLU_OP_SWIGLU_OAI: {
+ result = ggml_cuda_op_swiglu_oai_single(gate_value, result);
+ break;
+ }
+ default:
+ result = result * gate_value;
+ break;
+ }
+ }
+ }
+ dst[j*stride_col_dst + threadIdx.x] = result;
+ }
+ }
+
+ if constexpr (!has_fusion) {
+ GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate);
+ }
+}
+
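+// Grid: (row blocks, dst channels, samples or tokens); block: (warp_size, nwarps, 1).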
+static std::pair<dim3, dim3> calc_launch_params(
+ const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
+ const int warp_size, const mmvq_parameter_table_id table_id) {
+ const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
+ const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
+ const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
+ return {block_nums, block_dims};
+}
+
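+// Dispatches to the fused kernel only for the ncols_dst == 1 instantiation; bias/gate
+// fusion with a larger batch size is not supported and triggers an assertion.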
+template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false>
+static void mul_mat_vec_q_switch_fusion(
+ const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+ const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
+ const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
+ const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
+ const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
+ const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared,
+ const uint32_t ids_stride, cudaStream_t stream) {
+
+ const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+ if constexpr (c_ncols_dst == 1) {
+ if (has_fusion) {
+ mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
+ (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
+ return;
+ }
+ }
+
+ GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
+
+ mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
+ (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
+}
+
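+// Instantiates the kernel for a compile-time column count of 1..8. The multi-token
+// MUL_MAT_ID case is handled separately with ncols_dst == 1 per block and one block in
+// the z dimension per token.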
+template <ggml_type type>
+static void mul_mat_vec_q_switch_ncols_dst(
+ const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+ const int ncols_x, const int nrows_x, const int ncols_dst,
+ const int stride_row_x, const int stride_col_y, const int stride_col_dst,
+ const int nchannels_x, const int nchannels_y, const int nchannels_dst,
+ const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+ const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+ const int ids_stride, cudaStream_t stream) {
+
+ GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
+ GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
+
+ const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
+ const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
+ const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
+
+ const int device = ggml_cuda_get_device();
+ const int warp_size = ggml_cuda_info().devices[device].warp_size;
+ const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
+
+ const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+ const bool has_ids = ids != nullptr;
+
+ if (has_ids && ncols_dst > 1) {
+ // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
+ constexpr int c_ncols_dst = 1;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
+ mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+ dims.first, dims.second, 0, ids_stride, stream);
+ return;
+ }
+
+ switch (ncols_dst) {
+ case 1: {
+ constexpr int c_ncols_dst = 1;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+ mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+ dims.first, dims.second, 0, ids_stride, stream);
+ } break;
+ case 2: {
+ constexpr int c_ncols_dst = 2;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+ mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+ dims.first, dims.second, 0, ids_stride, stream);
+ } break;
+ case 3: {
+ constexpr int c_ncols_dst = 3;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+ mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+ dims.first, dims.second, 0, ids_stride, stream);
+ } break;
+ case 4: {
+ constexpr int c_ncols_dst = 4;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+ mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+ dims.first, dims.second, 0, ids_stride, stream);
+ } break;
+ case 5: {
+ constexpr int c_ncols_dst = 5;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+ mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+ dims.first, dims.second, 0, ids_stride, stream);
+ } break;
+ case 6: {
+ constexpr int c_ncols_dst = 6;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+ mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+ dims.first, dims.second, 0, ids_stride, stream);
+ } break;
+ case 7: {
+ constexpr int c_ncols_dst = 7;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+ mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+ dims.first, dims.second, 0, ids_stride, stream);
+ } break;
+ case 8: {
+ constexpr int c_ncols_dst = 8;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+ mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+ dims.first, dims.second, 0, ids_stride, stream);
+ } break;
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+
+ GGML_UNUSED(has_fusion);
+}
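+
+// Runtime dispatch on the quantized src0 type to the templated ncols_dst switch above.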
+static void mul_mat_vec_q_switch_type(
+ const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
+ const int ncols_x, const int nrows_x, const int ncols_dst,
+ const int stride_row_x, const int stride_col_y, const int stride_col_dst,
+ const int nchannels_x, const int nchannels_y, const int nchannels_dst,
+ const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+ const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+ const int ids_stride, cudaStream_t stream) {
+ switch (type_x) {
+ case GGML_TYPE_Q4_0:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_Q4_1:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_Q5_0:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_Q5_1:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_Q8_0:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_MXFP4:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_Q2_K:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_Q3_K:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_Q4_K:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_Q5_K:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_Q6_K:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_IQ2_XXS:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_IQ2_XS:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_IQ2_S:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_IQ3_XXS:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_IQ1_S:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_IQ1_M:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_IQ4_NL:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_IQ4_XS:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ case GGML_TYPE_IQ3_S:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
+ (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+ break;
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+}
+
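+// Host entry point: clears any padding of a temporary src0 compute buffer, quantizes src1
+// to q8_1 (rows padded to MATRIX_ROW_PADDING), and dispatches on src0->type. The optional
+// ids tensor selects the expert channels for MUL_MAT_ID.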
+void ggml_cuda_mul_mat_vec_q(
+ ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
+ const ggml_cuda_mm_fusion_args_host * fusion) {
+ GGML_ASSERT( src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ cudaStream_t stream = ctx.stream();
+
+ const size_t ts_src0 = ggml_type_size(src0->type);
+ const size_t ts_src1 = ggml_type_size(src1->type);
+ const size_t ts_dst = ggml_type_size(dst->type);
+
+ GGML_ASSERT( nb00 == ts_src0);
+ GGML_ASSERT( nb10 == ts_src1);
+ GGML_ASSERT( nb0 == ts_dst);
+ GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+
+ GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE);
+
+ const float * src1_d = (const float *) src1->data;
+ const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
+ float * dst_d = (float *) dst->data;
+
+ ggml_cuda_mm_fusion_args_device fusion_local{};
+
+ if (fusion) {
+ GGML_ASSERT( !ids || dst->ne[2] == 1);
+ GGML_ASSERT( ids || dst->ne[1] == 1);
+
+ if (fusion->x_bias) {
+ GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
+ GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
+ GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
+ fusion_local.x_bias = fusion->x_bias->data;
+ }
+ if (fusion->gate) {
+ GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
+ fusion_local.gate = fusion->gate->data;
+ }
+ if (fusion->gate_bias) {
+ GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
+ GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
+ GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
+ fusion_local.gate_bias = fusion->gate_bias->data;
+ }
+ fusion_local.glu_op = fusion->glu_op;
+ }
+
+ // If src0 is a temporary compute buffer, clear any potential padding.
+ if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
+ const size_t size_data = ggml_nbytes(src0);
+ const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
+ if (size_alloc > size_data) {
+ GGML_ASSERT(ggml_is_contiguously_allocated(src0));
+ GGML_ASSERT(!src0->view_src);
+ CUDA_CHECK(cudaMemsetAsync((char *) src0->data + size_data, 0, size_alloc - size_data, stream));
+ }
+ }
+
+ const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+ ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1);
+ {
+ const int64_t s11 = src1->nb[1] / ts_src1;
+ const int64_t s12 = src1->nb[2] / ts_src1;
+ const int64_t s13 = src1->nb[3] / ts_src1;
+ quantize_row_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
+ }
+
+ const int64_t s01 = src0->nb[1] / ts_src0;
+ const int64_t s11 = ne10_padded / QK8_1;
+ const int64_t s1 = dst->nb[1] / ts_dst;
+ const int64_t s02 = src0->nb[2] / ts_src0;
+ const int64_t s2 = dst->nb[2] / ts_dst;
+ const int64_t s03 = src0->nb[3] / ts_src0;
+ const int64_t s3 = dst->nb[3] / ts_dst;
+
+ const int64_t s12 = ne11*s11;
+ const int64_t s13 = ne12*s12;
+
+    // For MUL_MAT_ID the memory layout is different from that of MUL_MAT:
+ const int64_t ncols_dst = ids ? ne2 : ne1;
+ const int64_t nchannels_y = ids ? ne11 : ne12;
+ const int64_t nchannels_dst = ids ? ne1 : ne2;
+ const int64_t stride_col_dst = ids ? s2 : s1;
+ const int64_t stride_col_y = ids ? s12 : s11;
+ const int64_t stride_channel_dst = ids ? s1 : s2;
+ const int64_t stride_channel_y = ids ? s11 : s12;
+
+ const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
+
+ mul_mat_vec_q_switch_type(
+ src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
+ ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
+ ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03, s13, s3, ids_stride, stream);
+}
+
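+// Variant used by the generic ggml_cuda_op_mul_mat split-buffer path: operates on a
+// pre-quantized q8_1 slice of src1 and the row range [row_low, row_high) of src0, with
+// no fusion and no MUL_MAT_ID support.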
+void ggml_cuda_op_mul_mat_vec_q(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t row_diff = row_high - row_low;
+
+ const int64_t ne10 = src1->ne[0];
+ GGML_ASSERT(ne10 % QK8_1 == 0);
+
+ const int64_t ne0 = dst->ne[0];
+
+ int id = ggml_cuda_get_device();
+
+ // the main device has a larger memory buffer to hold the results from all GPUs
+ // nrows_dst == nrows of the matrix that the kernel writes into
+ const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
+
+ const int stride_row_x = ne00 / ggml_blck_size(src0->type);
+ const int stride_col_y = src1_padded_row_size / QK8_1;
+
+ ggml_cuda_mm_fusion_args_device fusion_local{};
+ mul_mat_vec_q_switch_type(
+ src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream);
+
+ GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/mmvq.cuh b/llama.cpp/ggml/src/ggml-cuda/mmvq.cuh
new file mode 100644
index 0000000..4bb10cf
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/mmvq.cuh
@@ -0,0 +1,12 @@
+#include "common.cuh"
+
+#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
+
+void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
+
+void ggml_cuda_op_mul_mat_vec_q(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream);
diff --git a/llama.cpp/ggml/src/ggml-cuda/norm.cu b/llama.cpp/ggml/src/ggml-cuda/norm.cu
new file mode 100644
index 0000000..ef98f67
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/norm.cu
@@ -0,0 +1,672 @@
+#include "norm.cuh"
+#include <cstdint>
+
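+// One block per (row, channel, sample): computes the row's mean and variance and writes
+// the standardized values (x - mean) / sqrt(var + eps) to a contiguous dst.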
+template <int block_size>
+static __global__ void norm_f32(
+ const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+ const int64_t stride_sample, const float eps) {
+ const int nrows = gridDim.x;
+ const int nchannels = gridDim.y;
+
+ const int row = blockIdx.x;
+ const int channel = blockIdx.y;
+ const int sample = blockIdx.z;
+ const int tid = threadIdx.x;
+
+ x += sample*stride_sample + channel*stride_channel + row*stride_row;
+ dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+
+ float2 mean_var = make_float2(0.0f, 0.0f);
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const float xi = x[col];
+ mean_var.x += xi;
+ mean_var.y += xi * xi;
+ }
+
+ // sum up partial sums
+ extern __shared__ float2 s_sum2[];
+ mean_var = block_reduce<block_reduce_method::SUM, block_size>(mean_var, s_sum2);
+
+ const float mean = mean_var.x / ncols;
+ const float var = mean_var.y / ncols - mean * mean;
+ const float inv_std = rsqrtf(var + eps);
+
+ for (int col = tid; col < ncols; col += block_size) {
+ dst[col] = (x[col] - mean) * inv_std;
+ }
+}
+
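+// One block per group: the first pass accumulates the group mean, the second pass writes
+// the centered values and accumulates the variance, then the values are rescaled in place.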
+template <int block_size>
+static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
+ // blockIdx.x: num_groups idx
+ // threadIdx.x: block_size idx
+ const int start = blockIdx.x*group_size + threadIdx.x;
+ const int end = min(blockIdx.x*group_size + group_size, ne_elements);
+
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int j = start; j < end; j += block_size) {
+ tmp += x[j];
+ }
+
+ extern __shared__ float s_sum[];
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
+
+ const float mean = tmp / group_size;
+ tmp = 0.0f;
+
+ for (int j = start; j < end; j += block_size) {
+ const float xi = x[j] - mean;
+ dst[j] = xi;
+ tmp += xi * xi;
+ }
+
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
+
+ const float variance = tmp / group_size;
+ const float scale = rsqrtf(variance + eps);
+ for (int j = start; j < end; j += block_size) {
+ dst[j] *= scale;
+ }
+}
+
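+// RMS norm with optional fused multiply (weight) and add operands; the mul/add shapes are
+// broadcast against x via fastmodulo indexing.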
+template <int block_size, bool do_multiply = false, bool do_add = false>
+static __global__ void rms_norm_f32(const float * x,
+ float * dst,
+ const int ncols,
+ const int64_t stride_row,
+ const int64_t stride_channel,
+ const int64_t stride_sample,
+ const float eps,
+ const float * mul = nullptr,
+ const int64_t mul_stride_row = 0,
+ const int64_t mul_stride_channel = 0,
+ const int64_t mul_stride_sample = 0,
+ const uint3 mul_ncols_packed = make_uint3(0, 0, 0),
+ const uint3 mul_nrows_packed = make_uint3(0, 0, 0),
+ const uint3 mul_nchannels_packed = make_uint3(0, 0, 0),
+ const uint3 mul_nsamples_packed = make_uint3(0, 0, 0),
+ const float * add = nullptr,
+ const int64_t add_stride_row = 0,
+ const int64_t add_stride_channel = 0,
+ const int64_t add_stride_sample = 0,
+ const uint3 add_ncols_packed = make_uint3(0, 0, 0),
+ const uint3 add_nrows_packed = make_uint3(0, 0, 0),
+ const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
+ const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
+ const int nrows = gridDim.x;
+ const int nchannels = gridDim.y;
+
+ const int row = blockIdx.x;
+ const int channel = blockIdx.y;
+ const int sample = blockIdx.z;
+ const int tid = threadIdx.x;
+
+ static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");
+
+ x += sample*stride_sample + channel*stride_channel + row*stride_row;
+ dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+
+ if constexpr (do_multiply) {
+ const uint32_t mul_row = fastmodulo(row, mul_nrows_packed);
+ const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
+ const uint32_t mul_sample = fastmodulo(sample, mul_nsamples_packed);
+ mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
+ }
+
+ if constexpr (do_add) {
+ const int add_row = fastmodulo(row, add_nrows_packed);
+ const int add_channel = fastmodulo(channel, add_nchannels_packed);
+ const int add_sample = fastmodulo(sample, add_nsamples_packed);
+ add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
+ }
+
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const float xi = x[col];
+ tmp += xi * xi;
+ }
+
+ // sum up partial sums
+ extern __shared__ float s_sum[];
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
+
+ const float mean = tmp / ncols;
+ const float scale = rsqrtf(mean + eps);
+
+ for (int col = tid; col < ncols; col += block_size) {
+ if constexpr (do_multiply && do_add) {
+ const int mul_col = fastmodulo(col, mul_ncols_packed);
+ const int add_col = fastmodulo(col, add_ncols_packed);
+ dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
+ } else if constexpr (do_multiply) {
+ const int mul_col = fastmodulo(col, mul_ncols_packed);
+ dst[col] = scale * x[col] * mul[mul_col];
+ } else {
+ dst[col] = scale * x[col];
+ }
+ }
+}
+
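+// Backward pass of RMS norm. With s = sqrt(mean(x^2) + eps) the gradient is
+//   dx = grad / s - x * sum(grad*x) / (s * (sum(x^2) + ncols*eps)),
+// which is why both sum(x^2) and sum(x*grad) are reduced per row.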
+template <int block_size>
+static __global__ void rms_norm_back_f32(
+ const float * grad, const float * xf, float * dst, const int ncols, const float eps) {
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ const int tid = threadIdx.x;
+
+ grad += int64_t(row)*ncols;
+ xf += int64_t(row)*ncols;
+ dst += int64_t(row)*ncols;
+
+ float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass
+ float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const float xfi = xf[col];
+ sum_xx += xfi * xfi;
+ sum_xg += xfi * grad[col];
+ }
+
+ // sum up partial sums
+ sum_xx = warp_reduce_sum(sum_xx);
+ sum_xg = warp_reduce_sum(sum_xg);
+ if constexpr (block_size > WARP_SIZE) {
+ static_assert(block_size == 1024, "unexpected block_size");
+ __shared__ float s_sum_xx[32];
+ __shared__ float s_sum_xg[32];
+ const int warp_id = threadIdx.x / WARP_SIZE;
+ const int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum_xx[warp_id] = sum_xx;
+ s_sum_xg[warp_id] = sum_xg;
+ }
+ __syncthreads();
+
+ sum_xx = s_sum_xx[lane_id];
+ sum_xx = warp_reduce_sum(sum_xx);
+
+ sum_xg = s_sum_xg[lane_id];
+ sum_xg = warp_reduce_sum(sum_xg);
+ }
+
+ const float mean_eps = sum_xx / ncols + eps;
+ const float sum_eps = sum_xx + ncols*eps;
+
+ const float scale_grad = rsqrtf(mean_eps);
+ const float scale_x = -scale_grad * sum_xg/sum_eps;
+
+ for (int col = tid; col < ncols; col += block_size) {
+ dst[col] = scale_grad*grad[col] + scale_x*xf[col];
+ }
+}
+
+// template <int block_size>
+// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
+// const int row = blockIdx.x*blockDim.y + threadIdx.y;
+// const int tid = threadIdx.x;
+
+// float tmp = 0.0f; // partial sum for thread in warp
+
+// for (int col = tid; col < ncols; col += block_size) {
+// const float xi = x[row*ncols + col];
+// tmp += xi * xi;
+// }
+
+// // sum up partial sums
+// tmp = warp_reduce_sum(tmp);
+// if (block_size > WARP_SIZE) {
+// __shared__ float s_sum[32];
+// int warp_id = threadIdx.x / WARP_SIZE;
+// int lane_id = threadIdx.x % WARP_SIZE;
+// if (lane_id == 0) {
+// s_sum[warp_id] = tmp;
+// }
+// __syncthreads();
+// tmp = s_sum[lane_id];
+// tmp = warp_reduce_sum(tmp);
+// }
+
+// // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
+// const float scale = rsqrtf(fmaxf(tmp, eps * eps));
+
+// for (int col = tid; col < ncols; col += block_size) {
+// dst[row*ncols + col] = scale * x[row*ncols + col];
+// }
+// }
+
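+// L2 normalization per (row, channel, sample): dst = x / max(||x||, eps).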
+template <int block_size>
+static __global__ void l2_norm_f32(
+ const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+ const int64_t stride_sample, const float eps) {
+ const int nrows = gridDim.x;
+ const int nchannels = gridDim.y;
+
+ const int row = blockIdx.x;
+ const int channel = blockIdx.y;
+ const int sample = blockIdx.z;
+ const int tid = threadIdx.x;
+
+ x += sample*stride_sample + channel*stride_channel + row*stride_row;
+ dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const float xi = x[col];
+ tmp += xi * xi;
+ }
+
+ // sum up partial sums
+ extern __shared__ float s_sum[];
+ tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum);
+
+ // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
+ const float scale = rsqrtf(fmaxf(tmp, eps * eps));
+
+ for (int col = tid; col < ncols; col += block_size) {
+ dst[col] = scale * x[col];
+ }
+}
+
+static void norm_f32_cuda(
+ const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+ const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+ const dim3 blocks_num(nrows, nchannels, nsamples);
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ }
+}
+
+static void group_norm_f32_cuda(
+ const float * x, float * dst, const int num_groups, const float eps, const int group_size, const int ne_elements, cudaStream_t stream) {
+ if (group_size < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ group_norm_f32<1024><<<num_groups, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps);
+ }
+}
+
+static void rms_norm_f32_cuda(
+ const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+ const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+ const dim3 blocks_num(nrows, nchannels, nsamples);
+ if (ncols < 1024) {
+ const dim3 block_dims(256, 1, 1);
+ rms_norm_f32<256, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ rms_norm_f32<1024, false><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ }
+}
+
+static void rms_norm_mul_f32_cuda(const float * x,
+ const float * mul,
+ const float * add,
+ float * dst,
+ const int ncols,
+ const int nrows,
+ const int nchannels,
+ const int nsamples,
+ const int64_t stride_row,
+ const int64_t stride_channel,
+ const int64_t stride_sample,
+ const int64_t mul_stride_row,
+ const int64_t mul_stride_channel,
+ const int64_t mul_stride_sample,
+ const uint32_t mul_ncols,
+ const uint32_t mul_nrows,
+ const uint32_t mul_nchannels,
+ const uint32_t mul_nsamples,
+ const int64_t add_stride_row,
+ const int64_t add_stride_channel,
+ const int64_t add_stride_sample,
+ const uint32_t add_ncols,
+ const uint32_t add_nrows,
+ const uint32_t add_nchannels,
+ const uint32_t add_nsamples,
+ const float eps,
+ cudaStream_t stream) {
+ const dim3 blocks_num(nrows, nchannels, nsamples);
+ if (mul == nullptr) {
+ rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
+ return;
+ }
+ if (add == nullptr) {
+ const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
+ const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
+ const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
+ const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
+ if (ncols < 1024) {
+ const dim3 block_dims(256, 1, 1);
+ rms_norm_f32<256, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ rms_norm_f32<1024, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
+ }
+ } else {
+ const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
+ const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
+ const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
+ const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
+
+ const uint3 add_ncols_packed = init_fastdiv_values(add_ncols);
+ const uint3 add_nrows_packed = init_fastdiv_values(add_nrows);
+ const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels);
+ const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
+ if (ncols < 1024) {
+ const dim3 block_dims(256, 1, 1);
+ rms_norm_f32<256, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
+ add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
+ add_nchannels_packed, add_nsamples_packed);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ rms_norm_f32<1024, true, true><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(
+ x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
+ mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
+ add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
+ add_nchannels_packed, add_nsamples_packed);
+ }
+ }
+}
+
+static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ rms_norm_back_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ rms_norm_back_f32<1024><<<nrows, block_dims, 0, stream>>>(grad, xf, dst, ncols, eps);
+ }
+}
+
+static void l2_norm_f32_cuda(
+ const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+ const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+ const dim3 blocks_num(nrows, nchannels, nsamples);
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ l2_norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+ }
+}
+
+void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *) src0->data;
+ float * dst_d = (float *) dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_TENSOR_UNARY_OP_LOCALS;
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+ GGML_ASSERT(eps >= 0.0f);
+
+ const size_t ts0 = ggml_type_size(src0->type);
+ GGML_ASSERT(nb00 == ts0);
+ const int64_t s01 = nb01 / ts0;
+ const int64_t s02 = nb02 / ts0;
+ const int64_t s03 = nb03 / ts0;
+
+ norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
+}
+
+void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ int num_groups = dst->op_params[0];
+
+ float eps;
+ memcpy(&eps, dst->op_params + 1, sizeof(float));
+ GGML_ASSERT(eps >= 0.0f);
+
+ int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
+ group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], eps, group_size, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *) src0->data;
+ float * dst_d = (float *) dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_TENSOR_UNARY_OP_LOCALS;
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+ GGML_ASSERT(eps >= 0.0f);
+
+ const size_t ts0 = ggml_type_size(src0->type);
+ GGML_ASSERT(nb00 == ts0);
+ const int64_t s01 = nb01 / ts0;
+ const int64_t s02 = nb02 / ts0;
+ const int64_t s03 = nb03 / ts0;
+
+ rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
+}
+
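+// Fuses an RMS_NORM node with the following MUL node (element-wise scale by the norm
+// weights) into one kernel launch; dst is the RMS_NORM node and mul_tensor is the MUL
+// node whose output buffer receives the result.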
+void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) {
+ const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
+ float eps = 0.0f;
+
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ const float * src0_d = (const float *) rms_norm_src->data;
+ const float * mul_d = nullptr;
+ const ggml_tensor * mul_src = nullptr;
+
+ if (mul_tensor->src[0] == dst) {
+ mul_d = (float *) mul_tensor->src[1]->data;
+ mul_src = mul_tensor->src[1];
+    } else if (mul_tensor->src[1] == dst) {
+ mul_d = (float *) mul_tensor->src[0]->data;
+ mul_src = mul_tensor->src[0];
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ float * dst_d = (float *) mul_tensor->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
+ GGML_ASSERT(eps >= 0.0f);
+
+ const int64_t ne00 = rms_norm_src->ne[0];
+ const int64_t ne01 = rms_norm_src->ne[1];
+ const int64_t ne02 = rms_norm_src->ne[2];
+ const int64_t ne03 = rms_norm_src->ne[3];
+
+ const size_t ts0 = ggml_type_size(rms_norm_src->type);
+ GGML_ASSERT(rms_norm_src->nb[0] == ts0);
+ const int64_t s01 = rms_norm_src->nb[1] / ts0;
+ const int64_t s02 = rms_norm_src->nb[2] / ts0;
+ const int64_t s03 = rms_norm_src->nb[3] / ts0;
+
+ const size_t ts_mul = ggml_type_size(mul_src->type);
+ GGML_ASSERT(mul_src->nb[0] == ts_mul);
+ const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
+ const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
+ const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
+
+ const int mul_ncols = mul_src->ne[0];
+ const int mul_nrows = mul_src->ne[1];
+ const int mul_nchannels = mul_src->ne[2];
+ const int mul_nsamples = mul_src->ne[3];
+
+ rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,
+ ne00, ne01, ne02, ne03,
+ /*s00*/ s01, s02, s03,
+ /*mul_s00*/ mul_s01, mul_s02, mul_s03,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ /*add_s00*/ 0, 0, 0,
+ 0, 0, 0, 0,
+ eps, stream);
+}
+
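+// Same as above, additionally fusing the ADD node that follows the MUL; the final result
+// is written to the ADD node's buffer.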
+void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
+ ggml_tensor * dst,
+ ggml_tensor * mul_tensor,
+ ggml_tensor * add_tensor) {
+ const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
+ float eps = 0.0f;
+
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ const float * src0_d = (const float *) rms_norm_src->data;
+ const float * mul_d = nullptr;
+ const ggml_tensor * mul_src = nullptr;
+
+ if (mul_tensor->src[0] == dst) {
+ mul_d = (float *) mul_tensor->src[1]->data;
+ mul_src = mul_tensor->src[1];
+ } else if (mul_tensor->src[1] == dst) {
+ mul_d = (float *) mul_tensor->src[0]->data;
+ mul_src = mul_tensor->src[0];
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ const float * add_d = nullptr;
+ const ggml_tensor * add_src = nullptr;
+
+ if (add_tensor->src[0] == mul_tensor) {
+ add_d = (float *) add_tensor->src[1]->data;
+ add_src = add_tensor->src[1];
+ } else if (add_tensor->src[1] == mul_tensor) {
+ add_d = (float *) add_tensor->src[0]->data;
+ add_src = add_tensor->src[0];
+ } else {
+ GGML_ASSERT(false);
+ }
+
+ float * dst_d = (float *) add_tensor->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
+ GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);
+ GGML_ASSERT(eps >= 0.0f);
+
+ const int64_t ne00 = rms_norm_src->ne[0];
+ const int64_t ne01 = rms_norm_src->ne[1];
+ const int64_t ne02 = rms_norm_src->ne[2];
+ const int64_t ne03 = rms_norm_src->ne[3];
+
+ const size_t ts0 = ggml_type_size(rms_norm_src->type);
+ GGML_ASSERT(rms_norm_src->nb[0] == ts0);
+ const int64_t s01 = rms_norm_src->nb[1] / ts0;
+ const int64_t s02 = rms_norm_src->nb[2] / ts0;
+ const int64_t s03 = rms_norm_src->nb[3] / ts0;
+
+ const size_t ts_mul = ggml_type_size(mul_src->type);
+ GGML_ASSERT(mul_src->nb[0] == ts_mul);
+ const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
+ const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
+ const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
+
+ const int mul_ncols = mul_src->ne[0];
+ const int mul_nrows = mul_src->ne[1];
+ const int mul_nchannels = mul_src->ne[2];
+ const int mul_nsamples = mul_src->ne[3];
+
+ const size_t ts_add = ggml_type_size(add_src->type);
+ GGML_ASSERT(add_src->nb[0] == ts_add);
+ const int64_t add_s01 = add_src->nb[1] / ts_add;
+ const int64_t add_s02 = add_src->nb[2] / ts_add;
+ const int64_t add_s03 = add_src->nb[3] / ts_add;
+
+ const int add_ncols = add_src->ne[0];
+ const int add_nrows = add_src->ne[1];
+ const int add_nchannels = add_src->ne[2];
+ const int add_nsamples = add_src->ne[3];
+
+ rms_norm_mul_f32_cuda(src0_d, mul_d, add_d, dst_d,
+ ne00, ne01, ne02, ne03,
+ /*s00*/ s01, s02, s03,
+ /*mul_s00*/ mul_s01, mul_s02, mul_s03,
+ mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+ /*add_s00*/ add_s01, add_s02, add_s03,
+ add_ncols, add_nrows, add_nchannels, add_nsamples,
+ eps, stream);
+}
+
+void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * grad = dst->src[0]; // gradients
+ const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass
+
+ const float * grad_d = (const float *) grad->data;
+ const float * src0f_d = (const float *) src0f->data;
+ float * dst_d = (float *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(grad));
+
+ GGML_ASSERT( grad->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0f->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int64_t ne00 = src0f->ne[0];
+ const int64_t nrows = ggml_nrows(src0f);
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+ GGML_ASSERT(eps >= 0.0f);
+
+ rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
+}
+
+void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *) src0->data;
+ float * dst_d = (float *) dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_TENSOR_UNARY_OP_LOCALS;
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+ GGML_ASSERT(eps >= 0.0f);
+
+ const size_t ts0 = ggml_type_size(src0->type);
+ GGML_ASSERT(nb00 == ts0);
+ const int64_t s01 = nb01 / ts0;
+ const int64_t s02 = nb02 / ts0;
+ const int64_t s03 = nb03 / ts0;
+
+ l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
+}
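For reference, the fused RMS_NORM → MUL (→ ADD) paths above (ggml_cuda_op_rms_norm_fused and ggml_cuda_op_rms_norm_fused_add) compute the same result as applying the ops separately. A minimal per-row scalar sketch, ignoring the broadcast strides; the helper name is illustrative only:

#include <cmath>
#include <cstdint>

// One row of length ne00; mul/add may be null when the corresponding op is not fused.
static void rms_norm_mul_add_row_ref(const float * x, const float * mul, const float * add,
                                     float * dst, int64_t ne00, float eps) {
    float sum_sq = 0.0f;
    for (int64_t i = 0; i < ne00; ++i) {
        sum_sq += x[i] * x[i];
    }
    const float scale = 1.0f / std::sqrt(sum_sq / ne00 + eps);  // RMS normalization factor
    for (int64_t i = 0; i < ne00; ++i) {
        float v = x[i] * scale;
        if (mul) { v *= mul[i]; }  // fused GGML_OP_MUL (the norm weights)
        if (add) { v += add[i]; }  // fused GGML_OP_ADD
        dst[i] = v;
    }
}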
diff --git a/llama.cpp/ggml/src/ggml-cuda/norm.cuh b/llama.cpp/ggml/src/ggml-cuda/norm.cuh
new file mode 100644
index 0000000..a74f637
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/norm.cuh
@@ -0,0 +1,18 @@
+#include "common.cuh"
+
+void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);
+
+void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
+ ggml_tensor * dst,
+ ggml_tensor * mul_tensor,
+ ggml_tensor * add_tensor);
+
+void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/opt-step-adamw.cu b/llama.cpp/ggml/src/ggml-cuda/opt-step-adamw.cu
new file mode 100644
index 0000000..35154f2
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/opt-step-adamw.cu
@@ -0,0 +1,78 @@
+#include "ggml-impl.h"
+#include "opt-step-adamw.cuh"
+
+#include <cstdint>
+
+static __global__ void opt_step_adamw_f32(
+ float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v,
+ const float * __restrict__ pars, const int64_t k) {
+
+ const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ const float alpha = pars[0];
+ const float beta1 = pars[1];
+ const float beta2 = pars[2];
+ const float eps = pars[3];
+ const float wd = pars[4];
+ const float beta1h = pars[5];
+ const float beta2h = pars[6];
+
+ const float gi = g[i];
+ const float gmi = g_m[i]*beta1 + gi*(1.0f - beta1);
+ const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2);
+
+ g_m[i] = gmi;
+ g_v[i] = gvi;
+
+ const float mh = gmi*beta1h;
+ const float vh = sqrtf(gvi*beta2h) + eps;
+
+ x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
+}
+
+static void opt_step_adamw_f32_cuda(
+ float * x, const float * g, float * g_m, float * g_v, const float * pars, const int64_t k, cudaStream_t stream) {
+
+ const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
+ const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1);
+ opt_step_adamw_f32<<<block_nums, block_dims, 0, stream>>>(x, g, g_m, g_v, pars, k);
+}
+
+void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src0_grad = dst->src[1];
+ const ggml_tensor * src0_grad_m = dst->src[2];
+ const ggml_tensor * src0_grad_v = dst->src[3];
+ const ggml_tensor * adamw_params = dst->src[4];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32);
+ GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src0_grad));
+ GGML_ASSERT(ggml_is_contiguous(src0_grad_m));
+ GGML_ASSERT(ggml_is_contiguous(src0_grad_v));
+ GGML_ASSERT(ggml_is_contiguous(adamw_params));
+ GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+ GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
+ GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
+ GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+
+ float * src0_d = (float *) src0->data;
+ const float * src0_grad_d = (const float *) src0_grad->data;
+ float * src0_grad_m_d = (float *) src0_grad_m->data;
+ float * src0_grad_v_d = (float *) src0_grad_v->data;
+ const float * adamw_params_d = (const float *) adamw_params->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ const int64_t ne = ggml_nelements(src0);
+
+ opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, adamw_params_d, ne, stream);
+}
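The per-element update performed by opt_step_adamw_f32 above, written out as a host-side scalar sketch. The parameter order matches the adamw_params tensor asserted to hold 7 floats (alpha, beta1, beta2, eps, wd, beta1h, beta2h); beta1h/beta2h are bias-correction factors precomputed by the caller.

#include <cmath>

static void adamw_step_ref(float & x, float g, float & m, float & v, const float pars[7]) {
    const float alpha = pars[0], beta1 = pars[1], beta2 = pars[2], eps = pars[3];
    const float wd = pars[4], beta1h = pars[5], beta2h = pars[6];
    m = m*beta1 + g*(1.0f - beta1);               // first-moment EMA
    v = v*beta2 + g*g*(1.0f - beta2);             // second-moment EMA
    const float mh = m*beta1h;                    // bias-corrected first moment
    const float vh = std::sqrt(v*beta2h) + eps;   // bias-corrected RMS of second moment
    x = x*(1.0f - alpha*wd) - alpha*mh/vh;        // decoupled weight decay, then Adam step
}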
diff --git a/llama.cpp/ggml/src/ggml-cuda/opt-step-adamw.cuh b/llama.cpp/ggml/src/ggml-cuda/opt-step-adamw.cuh
new file mode 100644
index 0000000..58d6f6e
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/opt-step-adamw.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_OPT_STEP_ADAMW_BLOCK_SIZE 256
+
+void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu b/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu
new file mode 100644
index 0000000..460b16d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cu
@@ -0,0 +1,49 @@
+#include "ggml-impl.h"
+#include "opt-step-sgd.cuh"
+
+#include <cstdint>
+
+static __global__ void opt_step_sgd_f32(
+ float * __restrict__ x, const float * __restrict__ g,
+ const float * __restrict__ pars, const int64_t k) {
+
+ const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ x[i] = x[i] * (1.0f - pars[0] * pars[1]) - pars[0] * g[i];
+}
+
+static void opt_step_sgd_f32_cuda(
+ float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) {
+
+ const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+ const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+ opt_step_sgd_f32<<<block_nums, block_dims, 0, stream>>>(x, g, pars, k);
+}
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src0_grad = dst->src[1];
+ const ggml_tensor * params = dst->src[2];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
+ GGML_ASSERT(params->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src0_grad));
+ GGML_ASSERT(ggml_is_contiguous(params));
+ GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+ GGML_ASSERT(ggml_nelements(params) == 2);
+
+ float * src0_d = (float *) src0->data;
+ const float * src0_grad_d = (const float *) src0_grad->data;
+ const float * params_d = (const float *) params->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ const int64_t ne = ggml_nelements(src0);
+
+ opt_step_sgd_f32_cuda(src0_d, src0_grad_d, params_d, ne, stream);
+}
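Likewise, opt_step_sgd_f32 above is plain SGD with decoupled weight decay. With the two packed parameters named as they are used in the kernel (learning rate first, weight-decay coefficient second), the per-element update is:

static inline float sgd_step_ref(float x, float g, float lr, float wd) {
    return x*(1.0f - lr*wd) - lr*g;  // x *= (1 - lr*wd); x -= lr*g
}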
diff --git a/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh b/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh
new file mode 100644
index 0000000..f97ab7d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/opt-step-sgd.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/out-prod.cu b/llama.cpp/ggml/src/ggml-cuda/out-prod.cu
new file mode 100644
index 0000000..c9b2b69
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/out-prod.cu
@@ -0,0 +1,68 @@
+#include "out-prod.cuh"
+
+#include <cstdint>
+
+void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(ne01 == ne11);
+ GGML_ASSERT(ne0 == ne00);
+ GGML_ASSERT(ne1 == ne10);
+
+ GGML_ASSERT(ne2 % src0->ne[2] == 0);
+ GGML_ASSERT(ne3 % src0->ne[3] == 0);
+
+ GGML_ASSERT(ne2 == src1->ne[2]);
+ GGML_ASSERT(ne3 == src1->ne[3]);
+
+ const float * src0_d = (const float *) src0->data;
+ const float * src1_d = (const float *) src1->data;
+ float * dst_d = (float *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
+ cublasHandle_t handle = ctx.cublas_handle();
+
+ const float alpha = 1.0f;
+ const float beta = 0.0f;
+
+ CUBLAS_CHECK(cublasSetStream(handle, stream));
+
+ const int64_t lda = nb01 / sizeof(float);
+ const int64_t ldc = nb1 / sizeof(float);
+
+ const bool src1_T = ggml_is_transposed(src1);
+ const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
+ const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
+ GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float));
+
+ // data strides in dimensions 2/3
+ const size_t s02 = nb02 / sizeof(float);
+ const size_t s03 = nb03 / sizeof(float);
+ const size_t s12 = nb12 / sizeof(float);
+ const size_t s13 = nb13 / sizeof(float);
+ const size_t s2 = nb2 / sizeof(float);
+ const size_t s3 = nb3 / sizeof(float);
+
+ // dps == dst per src0, used for group query attention
+ const int64_t dps2 = ne2 / ne02;
+ const int64_t dps3 = ne3 / ne03;
+
+ // TODO batched matrix multiplication
+ for (int64_t i3 = 0; i3 < ne3; ++i3) {
+ for (int64_t i2 = 0; i2 < ne2; ++i2) {
+ CUBLAS_CHECK(
+ cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
+ ne0, ne1, ne01,
+ &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
+ src1_d + i3 *s13 + i2 *s12, ldb,
+ &beta, dst_d + i3 *s3 + i2 *s2, ldc));
+ }
+ }
+}
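In cuBLAS' column-major view, each (i2, i3) slice above computes dst = src0 · src1ᵀ, i.e. a sum of outer products over the shared dimension ne01 == ne11. A scalar reference for one slice, assuming contiguous, non-transposed tensors (the broadcast of src0 over dims 2/3 via dps2/dps3 is elided):

#include <cstdint>

static void out_prod_slice_ref(const float * src0, const float * src1, float * dst,
                               int64_t ne00, int64_t ne01, int64_t ne10) {
    // dst[i0, i1] = sum_k src0[i0, k] * src1[i1, k]
    for (int64_t i1 = 0; i1 < ne10; ++i1) {
        for (int64_t i0 = 0; i0 < ne00; ++i0) {
            float acc = 0.0f;
            for (int64_t k = 0; k < ne01; ++k) {
                acc += src0[k*ne00 + i0] * src1[k*ne10 + i1];
            }
            dst[i1*ne00 + i0] = acc;
        }
    }
}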
diff --git a/llama.cpp/ggml/src/ggml-cuda/out-prod.cuh b/llama.cpp/ggml/src/ggml-cuda/out-prod.cuh
new file mode 100644
index 0000000..a0046f5
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/out-prod.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/pad.cu b/llama.cpp/ggml/src/ggml-cuda/pad.cu
new file mode 100644
index 0000000..31cd00f
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/pad.cu
@@ -0,0 +1,106 @@
+#include "pad.cuh"
+
+#include <stdint.h>
+
+__device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) {
+ // adding size first keeps the result non-negative for coords in [-size, 0)
+ return (coord + size) % size;
+}
+
+static __global__ void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst,
+ const int lp0, const int rp0, const int lp1, const int rp1,
+ const int lp2, const int rp2, const int lp3, const int rp3,
+ const int ne0, const int ne1, const int ne2, const int ne3,
+ const bool circular) {
+ // blockIdx.z: i3*ne2+i2
+ // blockIdx.y: i1
+ // blockIdx.x: i0 / CUDA_PAD_BLOCK_SIZE
+ // gridDim.y: ne1
+ int i0 = threadIdx.x + blockIdx.x * blockDim.x;
+ int i1 = blockIdx.y;
+ int i2 = blockIdx.z % ne2;
+ int i3 = blockIdx.z / ne2;
+
+ if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+ return;
+ }
+
+ const int64_t dst_idx = i3 * (ne0 * ne1 * ne2) + i2 * (ne0 * ne1) + i1 * ne0 + i0;
+
+ if (!circular) {
+ if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && (i2 >= lp2 && i2 < ne2 - rp2) &&
+ (i3 >= lp3 && i3 < ne3 - rp3)) {
+ const int64_t i00 = i0 - lp0;
+ const int64_t i01 = i1 - lp1;
+ const int64_t i02 = i2 - lp2;
+ const int64_t i03 = i3 - lp3;
+
+ const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;
+
+ dst[dst_idx] = src[src_idx];
+ } else {
+ dst[dst_idx] = 0.0f;
+ }
+ }
+ // circular means on a torus: out-of-range coordinates wrap around in every dimension
+ else {
+ const int64_t ne00 = ne0 - lp0 - rp0;
+ const int64_t ne01 = ne1 - lp1 - rp1;
+ const int64_t ne02 = ne2 - lp2 - rp2;
+ const int64_t ne03 = ne3 - lp3 - rp3;
+
+ const int64_t i00 = wrap_around(i0 - lp0, ne00);
+ const int64_t i01 = wrap_around(i1 - lp1, ne01);
+ const int64_t i02 = wrap_around(i2 - lp2, ne02);
+ const int64_t i03 = wrap_around(i3 - lp3, ne03);
+
+ const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;
+
+ dst[dst_idx] = src[src_idx];
+ }
+}
+
+
+static void pad_f32_cuda(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst,
+ const int lp0, const int rp0, const int lp1, const int rp1,
+ const int lp2, const int rp2, const int lp3, const int rp3,
+ const int ne0, const int ne1, const int ne2, const int ne3,
+ const bool circular, cudaStream_t stream) {
+ int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+ dim3 gridDim(num_blocks, ne1, ne2 * ne3);
+ pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, s00, s01, s02, s03, dst,
+ lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
+ ne0, ne1, ne2, ne3, circular);
+}
+
+void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *) src0->data;
+ float * dst_d = (float *) dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_TENSOR_UNARY_OP_LOCALS;
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ const int32_t lp0 = ((const int32_t *) (dst->op_params))[0];
+ const int32_t rp0 = ((const int32_t *) (dst->op_params))[1];
+ const int32_t lp1 = ((const int32_t *) (dst->op_params))[2];
+ const int32_t rp1 = ((const int32_t *) (dst->op_params))[3];
+ const int32_t lp2 = ((const int32_t *) (dst->op_params))[4];
+ const int32_t rp2 = ((const int32_t *) (dst->op_params))[5];
+ const int32_t lp3 = ((const int32_t *) (dst->op_params))[6];
+ const int32_t rp3 = ((const int32_t *) (dst->op_params))[7];
+ const int32_t circular = ((const int32_t *) (dst->op_params))[8];
+
+ const size_t s00 = nb00 / ggml_type_size(src0->type);
+ const size_t s01 = nb01 / ggml_type_size(src0->type);
+ const size_t s02 = nb02 / ggml_type_size(src0->type);
+ const size_t s03 = nb03 / ggml_type_size(src0->type);
+
+ pad_f32_cuda(src0_d, s00, s01, s02, s03, dst_d,
+ lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
+ dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+ (bool) circular, stream);
+}
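A small worked example of the circular branch in pad_f32: for a single row [a, b, c, d] (ne00 = 4) padded with lp0 = rp0 = 2 (so ne0 = 8), every destination index reads src[wrap_around(i0 - lp0, ne00)]:

//   i0:           0  1  2  3  4  5  6  7
//   i0 - lp0:    -2 -1  0  1  2  3  4  5
//   wrapped:      2  3  0  1  2  3  0  1
//   dst:          c  d  a  b  c  d  a  b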
diff --git a/llama.cpp/ggml/src/ggml-cuda/pad.cuh b/llama.cpp/ggml/src/ggml-cuda/pad.cuh
new file mode 100644
index 0000000..8fd386b
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/pad.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_PAD_BLOCK_SIZE 256
+
+void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu b/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu
new file mode 100644
index 0000000..32993eb
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cu
@@ -0,0 +1,91 @@
+#include "pad_reflect_1d.cuh"
+
+static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void
+ pad_reflect_1d_kernel_f32(
+ const void * __restrict__ src0,
+ void * __restrict__ dst,
+ const int64_t ne0,
+ const int64_t ne00,
+ const uint3 ne01,
+ const int64_t ne02,
+ const int64_t ne03,
+ const int64_t nb00,
+ const int64_t nb01,
+ const int64_t nb02,
+ const int64_t nb03,
+ const int64_t nb0,
+ const int64_t nb1,
+ const int64_t nb2,
+ const int64_t nb3,
+ const int p0,
+ const int p1) {
+ const int64_t i3 = blockIdx.z;
+ const int64_t i2 = blockIdx.y;
+
+ const uint2 div_mod_packed = fast_div_modulo(blockIdx.x, ne01);
+ const int64_t tile1 = div_mod_packed.y; // i1
+ const int64_t tile0 = div_mod_packed.x; // nth i0 tile
+ const int64_t i1 = tile1;
+ const int64_t i0 = threadIdx.x + tile0 * blockDim.x;
+
+ // ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh)
+ if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) {
+ return;
+ }
+
+ const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
+ char * dst_ptr = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
+
+ const int64_t rel_i0 = i0 - p0; // relative i0 in src0
+ int64_t src_idx;
+
+ if (rel_i0 < 0) {
+ // Left padding - reflect
+ src_idx = -rel_i0;
+ } else if (rel_i0 < ne00) {
+ // Middle - copy
+ src_idx = rel_i0;
+ } else {
+ // Right padding - reflect
+ src_idx = 2 * ne00 - 2 - rel_i0;
+ }
+ const float value = *(const float *) (src0_ptr + src_idx * nb00);
+ *(float *) (dst_ptr + i0 * nb0) = value;
+
+ GGML_UNUSED(p1);
+}
+
+void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ const int32_t * opts = (const int32_t *) dst->op_params;
+ const int p0 = opts[0];
+ const int p1 = opts[1];
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t ne01 = src0->ne[1];
+ const uint3 ne01_packed = init_fastdiv_values(ne01);
+ const int64_t ne02 = src0->ne[2];
+ const int64_t ne03 = src0->ne[3];
+
+ const int64_t ne0 = dst->ne[0];
+
+ // sanity: padded length matches
+ GGML_ASSERT(ne0 == ne00 + p0 + p1);
+
+ constexpr int64_t bx = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x)
+ const int64_t tiles0 = (ne0 + bx - 1) / bx; // number of tiles along i0
+ // grid.x covers i1 and all tiles of i0: [ne01 * tiles0]
+ // grid.y covers i2: [ne02]
+ // grid.z covers i3: [ne03]
+ const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03);
+ const dim3 block_dims((unsigned) bx, 1, 1);
+
+ pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
+ src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1);
+}
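For pad_reflect_1d_kernel_f32, the index arithmetic reflects about the first and last element without repeating them. For example, with ne00 = 5, p0 = p1 = 2 (ne0 = 9) and a source row [a, b, c, d, e]:

//   i0:        0  1  2  3  4  5  6  7  8
//   rel_i0:   -2 -1  0  1  2  3  4  5  6
//   src_idx:   2  1  0  1  2  3  4  3  2
//   dst:       c  b  a  b  c  d  e  d  c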
diff --git a/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh b/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh
new file mode 100644
index 0000000..15f2ed1
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/pad_reflect_1d.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_PAD_REFLECT_1D_BLOCK_SIZE 256
+
+void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/pool2d.cu b/llama.cpp/ggml/src/ggml-cuda/pool2d.cu
new file mode 100644
index 0000000..c6d51e4
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/pool2d.cu
@@ -0,0 +1,94 @@
+#include "pool2d.cuh"
+
+template <typename Ti, typename To>
+static __global__ void pool2d_nchw_kernel(
+ const int ih, const int iw, const int oh, const int ow,
+ const int kh, const int kw, const int sh, const int sw,
+ const int ph, const int pw, const int parallel_elements,
+ const Ti* src, To* dst, const enum ggml_op_pool op) {
+ int idx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (idx >= parallel_elements) {
+ return;
+ }
+
+ const int I_HW = ih * iw;
+ const int O_HW = oh * ow;
+ const int nc = idx / O_HW;
+ const int cur_oh = idx % O_HW / ow;
+ const int cur_ow = idx % O_HW % ow;
+ const Ti* i_ptr = src + nc * I_HW;
+ To* o_ptr = dst + nc * O_HW;
+ const int start_h = cur_oh * sh - ph;
+ const int bh = max(0, start_h);
+ const int eh = min(ih, start_h + kh);
+ const int start_w = cur_ow * sw - pw;
+ const int bw = max(0, start_w);
+ const int ew = min(iw, start_w + kw);
+ const To scale = 1. / (kh * kw);
+ To res = 0;
+
+ switch (op) {
+ case GGML_OP_POOL_AVG: res = 0; break;
+ case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
+ default: assert(false);
+ }
+
+ for (int i = bh; i < eh; i += 1) {
+ for (int j = bw; j < ew; j += 1) {
+#if __CUDA_ARCH__ >= 350
+ Ti cur = __ldg(i_ptr + i * iw + j);
+#else
+ Ti cur = i_ptr[i * iw + j];
+#endif
+ switch (op) {
+ case GGML_OP_POOL_AVG: res += cur * scale; break;
+ case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;
+ default: assert(false);
+ }
+ }
+ }
+ o_ptr[cur_oh * ow + cur_ow] = res;
+}
+
+static void pool2d_nchw_kernel_f32_f32_cuda(
+ const int ih, const int iw, const int oh, const int ow,
+ const int kh, const int kw, const int sh, const int sw,
+ const int ph, const int pw, const int parallel_elements,
+ const float * src, float * dst, const enum ggml_op_pool op,
+ cudaStream_t stream) {
+
+ const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
+ dim3 block_nums(num_blocks);
+ pool2d_nchw_kernel<<<block_nums, CUDA_POOL2D_BLOCK_SIZE, 0, stream>>>(ih, iw, oh, ow, kh, kw, sh, sw, ph, pw, parallel_elements, src, dst, op);
+}
+
+void ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int32_t * opts = (const int32_t *)dst->op_params;
+ enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
+ const int k0 = opts[1];
+ const int k1 = opts[2];
+ const int s0 = opts[3];
+ const int s1 = opts[4];
+ const int p0 = opts[5];
+ const int p1 = opts[6];
+
+ const int64_t IH = src0->ne[1];
+ const int64_t IW = src0->ne[0];
+
+ const int64_t N = dst->ne[3];
+ const int64_t OC = dst->ne[2];
+ const int64_t OH = dst->ne[1];
+ const int64_t OW = dst->ne[0];
+
+ const int parallel_elements = N * OC * OH * OW;
+
+ pool2d_nchw_kernel_f32_f32_cuda(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_d, dst_d, op, stream);
+}
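Two details of ggml_cuda_op_pool2d above worth noting (both read directly from the code):

// - op_params = { op, k0, k1, s0, s1, p0, p1 } is laid out along ggml's dim0/dim1
//   (width first), while pool2d_nchw_kernel takes height-first arguments, hence the
//   (k1, k0), (s1, s0), (p1, p0) swap in the launch call.
// - GGML_OP_POOL_AVG accumulates cur * 1/(kh*kw), so windows clipped at the borders
//   (or overlapping the zero padding) are still divided by the full kernel area.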
diff --git a/llama.cpp/ggml/src/ggml-cuda/pool2d.cuh b/llama.cpp/ggml/src/ggml-cuda/pool2d.cuh
new file mode 100644
index 0000000..7841292
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/pool2d.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_POOL2D_BLOCK_SIZE 256
+
+void ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/quantize.cu b/llama.cpp/ggml/src/ggml-cuda/quantize.cu
new file mode 100644
index 0000000..a8c68e4
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/quantize.cu
@@ -0,0 +1,343 @@
+#include "quantize.cuh"
+#include <cstdint>
+
+__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1)
+static __global__ void quantize_q8_1(
+ const float * __restrict__ x, void * __restrict__ vy,
+ const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+ const int64_t ne0, const uint32_t ne1, const uint3 ne2) {
+ const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i0 >= ne0) {
+ return;
+ }
+
+ const int64_t i3 = fastdiv(blockIdx.z, ne2);
+ const int64_t i2 = blockIdx.z - i3*ne2.z;
+ const int64_t i1 = blockIdx.y;
+
+ const int64_t & i00 = i0;
+ const int64_t & i01 = i1;
+ const int64_t & i02 = i2;
+ const int64_t & i03 = i3;
+
+ const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0;
+
+ block_q8_1 * y = (block_q8_1 *) vy;
+
+ const int64_t ib = i_cont / QK8_1; // block index
+ const int64_t iqs = i_cont % QK8_1; // quant index
+
+ const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f;
+ float amax = fabsf(xi);
+ float sum = xi;
+
+ amax = warp_reduce_max<QK8_1>(amax);
+ sum = warp_reduce_sum<QK8_1>(sum);
+
+ const float d = amax / 127.0f;
+ const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
+
+ y[ib].qs[iqs] = q;
+
+ if (iqs > 0) {
+ return;
+ }
+
+ y[ib].ds = make_half2(d, sum);
+}
+
+__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
+ if (!(amax > 0.0f)) {
+ return 0;
+ }
+
+ // FP4 E2M1: max exponent (unbiased) is 2.
+ constexpr int FP4_E2M1_EMAX = 2;
+
+ const float e = log2f(amax);
+
+ // "even" -> round-to-nearest integer, ties-to-even
+ const int e_int = __float2int_rn(e);
+
+ const int shared_exp = e_int - FP4_E2M1_EMAX;
+
+ int biased = shared_exp + 127;
+
+ biased = max(biased, 0);
+ biased = min(biased, 254);
+
+ return static_cast<uint8_t>(biased);
+}
+
+ // quantize values into the interleaved-nibble layout used for mxfp4 storage,
+ // i.e. a block a0..a31 is stored as a0a16, a1a17, ..., a15a31
+static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
+ const int32_t * __restrict__ ids,
+ void * __restrict__ vy,
+ const int64_t ne00,
+ const int64_t s01,
+ const int64_t s02,
+ const int64_t s03,
+ const int64_t ne0,
+ const int ne1,
+ const int ne2) {
+ constexpr int vals_per_scale = 32;
+ constexpr int vals_per_warp = 2 * vals_per_scale; // Each warp processes 2 blocks of 32 = 64 values
+
+ const int warp_id = threadIdx.y;
+ const int lane_id_32 = threadIdx.x;
+
+ const int nwarps = blockDim.y;
+
+ const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;
+
+ if (warp_start_offset >= ne0) {
+ return;
+ }
+
+ const int64_t i1 = blockIdx.x;
+ const int64_t i2 = blockIdx.z % ne2;
+ const int64_t i3 = blockIdx.z / ne2;
+
+ const int64_t i01 = ids ? ids[i1] : i1;
+ const int64_t i02 = i2;
+ const int64_t i03 = i3;
+
+ block_fp4_mmq * y = (block_fp4_mmq *) vy;
+
+ const int64_t block_fp4_mmq_size = 8 * QK_MXFP4; // 256 values
+ const int64_t ib0 = blockIdx.z * ((int64_t) ne1 * (ne0 / block_fp4_mmq_size));
+ const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
+ const int64_t quad_idx_in_block = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
+
+ const int group_id = lane_id_32 / 4;
+ const int lane_in_group = lane_id_32 % 4;
+ const int base = group_id * 2;
+ char2 * yqs2 = (char2 *) y[ib].qs;
+
+ const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
+
+ uint8_t scales[2];
+
+#pragma unroll
+ for (int b = 0; b < 2; ++b) {
+ const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32;
+ const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f;
+
+ float amax = fabsf(xi);
+#pragma unroll
+ for (int mask = 16; mask > 0; mask >>= 1) {
+ amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+ }
+
+ const uint8_t e = compute_e8m0_scale(amax);
+ scales[b] = e;
+ const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));
+
+#if CUDART_VERSION >= 12080
+ const float scaled_val = xi * inv_s;
+
+ const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE);
+ const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE);
+ const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE);
+ const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE);
+
+ if (lane_in_group == 0) {
+ __nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3));
+
+ yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = *(char2 *) &fp4_packed;
+ }
+#else
+ // Fallback: manual FP4 conversion using LUT
+ const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
+
+ const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base, WARP_SIZE);
+ const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1, WARP_SIZE);
+ const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE);
+ const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE);
+
+ if (lane_in_group == 0) {
+ char2 q;
+ q.x = (q_hi_0 << 4) | q_lo_0;
+ q.y = (q_hi_1 << 4) | q_lo_1;
+ yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = q;
+ }
+#endif // CUDART_VERSION >= 12080
+ }
+
+ if (lane_id_32 == 0) {
+ // Store 2 scales packed into 1 uint32
+ y[ib].d4[quad_idx_in_block] = (scales[1] << 8) | scales[0];
+ }
+}
+
+template <mmq_q8_1_ds_layout ds_layout>
+static __global__ void quantize_mmq_q8_1(
+ const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
+ const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+ const int64_t ne0, const int ne1, const int ne2) {
+
+ constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
+ constexpr int vals_per_sum = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
+
+ const int64_t i0 = ((int64_t)blockDim.x*blockIdx.y + threadIdx.x)*4;
+
+ if (i0 >= ne0) {
+ return;
+ }
+
+ const int64_t i1 = blockIdx.x;
+ const int64_t i2 = blockIdx.z % ne2;
+ const int64_t i3 = blockIdx.z / ne2;
+
+ const int64_t i00 = i0;
+ const int64_t i01 = ids ? ids[i1] : i1;
+ const int64_t i02 = i2;
+ const int64_t i03 = i3;
+
+ const float4 * x4 = (const float4 *) x;
+
+ block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
+
+ const int64_t ib0 = blockIdx.z*((int64_t)gridDim.x*gridDim.y*blockDim.x/QK8_1); // first block of channel
+ const int64_t ib = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.x; // block index in channel
+ const int64_t iqs = i0 % (4*QK8_1); // quant index in block
+
+ // Load 4 floats per thread and calculate max. abs. value between them:
+ const float4 xi = i0 < ne00 ? x4[(i03*s03 + i02*s02 + i01*s01 + i00)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
+ float amax = fabsf(xi.x);
+ amax = fmaxf(amax, fabsf(xi.y));
+ amax = fmaxf(amax, fabsf(xi.z));
+ amax = fmaxf(amax, fabsf(xi.w));
+
+ // Exchange max. abs. value between vals_per_scale/4 threads.
+#pragma unroll
+ for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
+ amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
+ }
+
+ float sum;
+ if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
+ sum = xi.x + xi.y + xi.z + xi.w;
+
+ // Calculate sums across vals_per_sum/4 threads.
+#pragma unroll
+ for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
+ sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
+ }
+ }
+
+ const float d_inv = 127.0f / amax;
+ char4 q;
+ q.x = roundf(xi.x*d_inv);
+ q.y = roundf(xi.y*d_inv);
+ q.z = roundf(xi.z*d_inv);
+ q.w = roundf(xi.w*d_inv);
+
+ // Write back 4 int8 values as a single 32 bit value for better memory bandwidth:
+ char4 * yqs4 = (char4 *) y[ib].qs;
+ yqs4[iqs/4] = q;
+
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
+ if (iqs % 16 != 0 || iqs >= 96) {
+ return;
+ }
+
+ y[ib].d2s6[2 + iqs/16] = sum;
+
+ if (iqs % 64 != 0) {
+ return;
+ }
+
+ const float d = 1.0f / d_inv;
+
+ y[ib].d2s6[iqs/64] = d;
+
+ return;
+ }
+
+ if (iqs % 32 != 0) {
+ return;
+ }
+
+ const float d = 1.0f / d_inv;
+
+ if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
+ y[ib].ds4[iqs/32] = make_half2(d, sum);
+ } else {
+ y[ib].d4[iqs/32] = d;
+ }
+}
+
+void quantize_row_q8_1_cuda(
+ const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
+ const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+ const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
+ GGML_ASSERT(!ids);
+ GGML_ASSERT(ne0 % QK8_1 == 0);
+
+ const uint3 ne2_fastdiv = init_fastdiv_values(ne2);
+
+ const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+ const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
+ const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
+ quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);
+ GGML_UNUSED(type_src0);
+}
+
+void quantize_mmq_q8_1_cuda(
+ const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
+ const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+ const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
+ GGML_ASSERT(ne00 % 4 == 0);
+ GGML_ASSERT(ne0 % (4*QK8_1) == 0);
+
+ // ne1 tends to be the largest of these dimensions, therefore use it as the "x" dimension of the CUDA grid:
+ const int64_t block_num_y = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
+ const dim3 num_blocks(ne1, block_num_y, ne2*ne3);
+ const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
+ switch (mmq_get_q8_1_ds_layout(type_src0)) {
+ case MMQ_Q8_1_DS_LAYOUT_D4:
+ quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
+ <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
+ break;
+ case MMQ_Q8_1_DS_LAYOUT_DS4:
+ quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
+ <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
+ break;
+ case MMQ_Q8_1_DS_LAYOUT_D2S6:
+ quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
+ <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
+ break;
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+}
+
+void quantize_mmq_mxfp4_cuda(const float * x,
+ const int32_t * ids,
+ void * vy,
+ [[maybe_unused]] const ggml_type type_src0,
+ const int64_t ne00,
+ const int64_t s01,
+ const int64_t s02,
+ const int64_t s03,
+ const int64_t ne0,
+ const int64_t ne1,
+ const int64_t ne2,
+ const int64_t ne3,
+ cudaStream_t stream) {
+ GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
+
+ constexpr int nwarps = 8;
+ constexpr int vals_per_warp = 2 * QK_MXFP4;
+ constexpr int vals_per_block = nwarps * vals_per_warp;
+
+ const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
+ const dim3 num_blocks(ne1, block_num_y, ne2 * ne3);
+ const dim3 block_size(WARP_SIZE, nwarps, 1);
+
+ quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
+}
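A scalar sketch of the basic q8_1 quantization that quantize_q8_1 above performs for one block of QK8_1 values: the block stores the scale d = amax/127 together with the pre-quantization sum (packed as a half2). The helper name is illustrative only:

#include <cmath>
#include <cstdint>

static void quantize_block_q8_1_ref(const float * x, int8_t * qs, float & d, float & s, int qk) {
    float amax = 0.0f;
    float sum  = 0.0f;
    for (int i = 0; i < qk; ++i) {
        amax = std::fmax(amax, std::fabs(x[i]));
        sum += x[i];
    }
    d = amax / 127.0f;
    s = sum;
    for (int i = 0; i < qk; ++i) {
        qs[i] = amax == 0.0f ? 0 : (int8_t) std::round(x[i] / d);
    }
}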
diff --git a/llama.cpp/ggml/src/ggml-cuda/quantize.cuh b/llama.cpp/ggml/src/ggml-cuda/quantize.cuh
new file mode 100644
index 0000000..6a91df6
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/quantize.cuh
@@ -0,0 +1,41 @@
+#pragma once
+
+#include "common.cuh"
+#include "mmq.cuh"
+
+#include <cstdint>
+
+#define CUDA_QUANTIZE_BLOCK_SIZE 256
+#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
+
+static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access.");
+static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
+
+typedef void (*quantize_cuda_t)(
+ const float * x, const int32_t * ids, void * vy,
+ ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+ int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
+
+void quantize_row_q8_1_cuda(
+ const float * x, const int32_t * ids, void * vy,
+ ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+ int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
+
+void quantize_mmq_q8_1_cuda(
+ const float * x, const int32_t * ids, void * vy,
+ ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+ int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
+
+void quantize_mmq_mxfp4_cuda(const float * x,
+ const int32_t * ids,
+ void * vy,
+ ggml_type type_src0,
+ int64_t ne00,
+ int64_t s01,
+ int64_t s02,
+ int64_t s03,
+ int64_t ne0,
+ int64_t ne1,
+ int64_t ne2,
+ int64_t ne3,
+ cudaStream_t stream);
diff --git a/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh b/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh
new file mode 100644
index 0000000..de240fd
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/reduce_rows.cuh
@@ -0,0 +1,39 @@
+#include "common.cuh"
+
+// Row reduction kernel template - compute sum (norm=false) or mean (norm=true)
+template <bool norm>
+static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) {
+ const int row = blockIdx.x;
+ const int col = threadIdx.x;
+
+ float sum = 0.0f;
+ const int num_unroll = 8;
+ float temp[num_unroll];
+ float sum_temp[num_unroll] = { 0.0f };
+ for (int i = col; i < ncols;) {
+ for (int j = 0; j < num_unroll; ++j) {
+ if (i < ncols) {
+ temp[j] = x[row * ncols + i];
+ } else {
+ temp[j] = 0;
+ }
+ i += blockDim.x;
+ }
+ for (int j = 0; j < num_unroll; ++j) {
+ sum_temp[j] += temp[j];
+ }
+ }
+ for (int j = 0; j < num_unroll; ++j) {
+ sum += sum_temp[j];
+ }
+
+ // sum up partial sums
+ __shared__ float shared_vals[32];
+ sum = block_reduce<block_reduce_method::SUM>(sum, shared_vals);
+
+ if (col != 0) {
+ return;
+ }
+
+ dst[row] = norm ? sum / ncols : sum;
+}
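The header above only defines the kernel template; callers launch it with one block per row. A minimal host-side launch sketch (the block size here is purely illustrative, not the value the actual callers pick):

// sum of each row:   reduce_rows_f32<false><<<nrows, 256, 0, stream>>>(x, dst, ncols);
// mean of each row:  reduce_rows_f32<true><<<nrows, 256, 0, stream>>>(x, dst, ncols);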
diff --git a/llama.cpp/ggml/src/ggml-cuda/roll.cu b/llama.cpp/ggml/src/ggml-cuda/roll.cu
new file mode 100644
index 0000000..a339dfc
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/roll.cu
@@ -0,0 +1,67 @@
+#include "ggml-cuda/common.cuh"
+#include "roll.cuh"
+
+static __forceinline__ __device__ int64_t wrap_index(const int64_t idx, const int64_t ne) {
+ if (idx < 0) {
+ return idx + ne;
+ }
+ if (idx >= ne) {
+ return idx - ne;
+ }
+ return idx;
+}
+
+static __global__ void roll_f32_cuda(const float * __restrict__ src,
+ float * __restrict__ dst,
+ const int64_t ne00,
+ const int64_t ne01,
+ const int64_t ne02,
+ const int64_t ne03,
+ const int s0,
+ const int s1,
+ const int s2,
+ const int s3) {
+ const int64_t idx = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+ const int64_t n_elements = ne00 * ne01 * ne02 * ne03;
+
+ if (idx >= n_elements) {
+ return;
+ }
+
+ const int64_t i0 = idx % ne00;
+ const int64_t i1 = (idx / ne00) % ne01;
+ const int64_t i2 = (idx / (ne00 * ne01)) % ne02;
+ const int64_t i3 = (idx / (ne00 * ne01 * ne02)) % ne03;
+
+ const int64_t d0 = wrap_index(i0 - s0, ne00);
+ const int64_t d1 = wrap_index(i1 - s1, ne01);
+ const int64_t d2 = wrap_index(i2 - s2, ne02);
+ const int64_t d3 = wrap_index(i3 - s3, ne03);
+
+ dst[i3 * (ne00 * ne01 * ne02) + i2 * (ne01 * ne00) + i1 * ne00 + i0] =
+ src[d3 * (ne00 * ne01 * ne02) + d2 * (ne01 * ne00) + d1 * ne00 + d0];
+}
+
+void ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ int s0 = dst->op_params[0];
+ int s1 = dst->op_params[1];
+ int s2 = dst->op_params[2];
+ int s3 = dst->op_params[3];
+
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *) dst->src[0]->data;
+ float * dst_d = (float *) dst->data;
+
+ GGML_TENSOR_UNARY_OP_LOCALS;
+
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst));
+
+ cudaStream_t stream = ctx.stream();
+
+ int64_t sz = (ne00 * ne01 * ne02 * ne03);
+ int64_t num_blocks = (sz + CUDA_ROLL_BLOCK_SIZE - 1) / CUDA_ROLL_BLOCK_SIZE;
+
+ roll_f32_cuda<<<num_blocks, CUDA_ROLL_BLOCK_SIZE, 0, stream>>>(
+ src0_d, dst_d, ne00, ne01, ne02, ne03, s0, s1, s2, s3);
+}
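ggml_cuda_op_roll above shifts each dimension independently: dst[i] = src[wrap_index(i - s, ne)], so a positive shift moves elements toward higher indices and wraps at the end. For one row [a, b, c, d, e] (ne00 = 5) with s0 = 2:

//   i0:                0  1  2  3  4
//   wrap(i0 - 2, 5):   3  4  0  1  2
//   dst:               d  e  a  b  c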
diff --git a/llama.cpp/ggml/src/ggml-cuda/roll.cuh b/llama.cpp/ggml/src/ggml-cuda/roll.cuh
new file mode 100644
index 0000000..322d554
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/roll.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_ROLL_BLOCK_SIZE 256
+
+void ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/rope.cu b/llama.cpp/ggml/src/ggml-cuda/rope.cu
new file mode 100644
index 0000000..45a49a5
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/rope.cu
@@ -0,0 +1,665 @@
+#include "convert.cuh"
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
+#include "rope.cuh"
+
+struct rope_corr_dims {
+ float v[2];
+};
+
+
+struct mrope_sections {
+ int v[4];
+};
+
+static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
+ const float y = (i0 / 2 - low) / max(0.001f, high - low);
+ return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
+// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
+template<bool forward>
+static __device__ void rope_yarn(
+ const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor,
+ float mscale, float & cos_theta, float & sin_theta) {
+ // Get n-d rotational scaling corrected for extrapolation
+ float theta_interp = freq_scale * theta_extrap;
+ float theta = theta_interp;
+ if (ext_factor != 0.0f) {
+ float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
+ theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+
+ // Get n-d magnitude scaling corrected for interpolation
+ mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+ }
+ cos_theta = cosf(theta) * mscale;
+ sin_theta = sinf(theta) * mscale;
+ if (!forward) {
+ sin_theta *= -1.0f;
+ }
+}
+
+template <bool forward, bool has_ff, typename T, typename D>
+static __global__ void rope_norm(const T * x,
+ D * dst,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int s01,
+ const int s02,
+ const int s03,
+ const int s1,
+ const int s2,
+ const int s3,
+ const int n_dims,
+ const int32_t * pos,
+ const float freq_scale,
+ const float ext_factor,
+ const float attn_factor,
+ const rope_corr_dims corr_dims,
+ const float theta_scale,
+ const float * freq_factors,
+ const int64_t * row_indices,
+ const int set_rows_stride) {
+ const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+ if (i0 >= ne00) {
+ return;
+ }
+
+ const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
+
+ const uint32_t i3 = row_dst / (ne01 * ne02);
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
+
+ int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
+ const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;
+ // Fusion optimization: ROPE + VIEW + SET_ROWS.
+ // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
+ if (set_rows_stride != 0) {
+ idst = i1 * s1 + i0;
+ idst += row_indices[i2] * set_rows_stride;
+ }
+
+ const auto & store_coalesced = [&](float x0, float x1) {
+ if constexpr (std::is_same_v<float, D>) {
+ float2 v = make_float2(x0, x1);
+ ggml_cuda_memcpy_1<8>(dst + idst, &v);
+ } else if constexpr (std::is_same_v<half, D>) {
+ half2 v = make_half2(x0, x1);
+ ggml_cuda_memcpy_1<4>(dst + idst, &v);
+ }
+ };
+ if (i0 >= n_dims) {
+ store_coalesced(x[ix + 0], x[ix + 1]);
+ return;
+ }
+
+ const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+ float cos_theta;
+ float sin_theta;
+
+ rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
+
+ const float x0 = x[ix + 0];
+ const float x1 = x[ix + 1];
+
+ store_coalesced(x0 * cos_theta - x1 * sin_theta, x0 * sin_theta + x1 * cos_theta);
+}
+
+template <bool forward, bool has_ff, typename T, typename D>
+static __global__ void rope_neox(const T * x,
+ D * dst,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int s01,
+ const int s02,
+ const int s03,
+ const int s1,
+ const int s2,
+ const int s3,
+ const int n_dims,
+ const int32_t * pos,
+ const float freq_scale,
+ const float ext_factor,
+ const float attn_factor,
+ const rope_corr_dims corr_dims,
+ const float theta_scale,
+ const float * freq_factors,
+ const int64_t * row_indices,
+ const int set_rows_stride) {
+ const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+ if (i0 >= ne00) {
+ return;
+ }
+
+ const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
+
+ const uint32_t i3 = row_dst / (ne01 * ne02);
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
+
+ int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
+ const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
+
+ // Fusion optimization: ROPE + VIEW + SET_ROWS.
+ // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
+ if (set_rows_stride != 0) {
+ idst = i1 * s1 + i0 / 2;
+ idst += row_indices[i2] * set_rows_stride;
+ }
+
+ if (i0 >= n_dims) {
+ dst[idst + i0 / 2 + 0] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 0]);
+ dst[idst + i0 / 2 + 1] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 1]);
+
+ return;
+ }
+
+ const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
+
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+ float cos_theta;
+ float sin_theta;
+
+ rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
+
+ const float x0 = x[ix + 0];
+ const float x1 = x[ix + n_dims/2];
+
+ dst[idst + 0] = ggml_cuda_cast<D>(x0 * cos_theta - x1 * sin_theta);
+ dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta);
+}
+
+template <bool forward, bool has_ff, typename T>
+static __global__ void rope_multi(const T * x,
+ T * dst,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int s01,
+ const int s02,
+ const int s03,
+ const int s1,
+ const int s2,
+ const int s3,
+ const int n_dims,
+ const int32_t * pos,
+ const float freq_scale,
+ const float ext_factor,
+ const float attn_factor,
+ const rope_corr_dims corr_dims,
+ const float theta_scale,
+ const float * freq_factors,
+ const mrope_sections sections,
+ const bool is_imrope) {
+ const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y);
+
+ if (i0 >= ne00) {
+ return;
+ }
+
+ const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
+
+ const uint32_t i3 = row_dst / (ne01 * ne02);
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
+
+ int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
+ const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
+
+ if (i0 >= n_dims) {
+ dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
+ dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
+
+ return;
+ }
+
+ const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
+ const int sec_w = sections.v[1] + sections.v[0];
+ const int sector = (i0 / 2) % sect_dims;
+
+ float theta_base = 0.0;
+ if (is_imrope) {
+ if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
+ theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
+ } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
+ theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
+ } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
+ theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
+ } else {
+ theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
+ }
+ } else {
+ if (sector < sections.v[0]) {
+ theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
+ } else if (sector >= sections.v[0] && sector < sec_w) {
+ theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
+ } else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
+ theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
+ } else if (sector >= sec_w + sections.v[2]) {
+ theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
+ }
+ }
+
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+ float cos_theta;
+ float sin_theta;
+
+ rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
+
+ const float x0 = x[ix + 0];
+ const float x1 = x[ix + n_dims/2];
+
+ dst[idst + 0] = x0*cos_theta - x1*sin_theta;
+ dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
+}
+
+template <bool forward, bool has_ff, typename T>
+static __global__ void rope_vision(const T * x,
+ T * dst,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int s01,
+ const int s02,
+ const int s03,
+ const int s1,
+ const int s2,
+ const int s3,
+ const int n_dims,
+ const int32_t * pos,
+ const float freq_scale,
+ const float ext_factor,
+ const float attn_factor,
+ const rope_corr_dims corr_dims,
+ const float theta_scale,
+ const float * freq_factors,
+ const mrope_sections sections) {
+ const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
+
+ if (i0 >= ne00) {
+ return;
+ }
+
+ const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
+
+ const uint32_t i3 = row_dst / (ne01 * ne02);
+ const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
+ const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
+
+ int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
+ const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
+
+ const int sect_dims = sections.v[0] + sections.v[1];
+ const int sec_w = sections.v[1] + sections.v[0];
+ const int sector = (i0 / 2) % sect_dims;
+
+ float theta_base = 0.0;
+ if (sector < sections.v[0]) {
+ const int p = sector;
+ theta_base = pos[i2] * powf(theta_scale, p);
+ } else if (sector >= sections.v[0] && sector < sec_w) {
+ const int p = sector - sections.v[0];
+ theta_base = pos[i2 + ne02] * powf(theta_scale, p);
+ }
+
+ const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
+
+ float cos_theta;
+ float sin_theta;
+
+ rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
+
+ const float x0 = x[ix + 0];
+ const float x1 = x[ix + n_dims];
+
+ dst[idst + 0] = x0*cos_theta - x1*sin_theta;
+ dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
+}
+
+template <bool forward, typename T, typename D>
+static void rope_norm_cuda(const T * x,
+ D * dst,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int s01,
+ const int s02,
+ const int s03,
+ const int s1,
+ const int s2,
+ const int s3,
+ const int n_dims,
+ const int nr,
+ const int32_t * pos,
+ const float freq_scale,
+ const float freq_base,
+ const float ext_factor,
+ const float attn_factor,
+ const rope_corr_dims corr_dims,
+ const float * freq_factors,
+ const int64_t * row_indices,
+ const int set_rows_stride,
+ cudaStream_t stream) {
+ GGML_ASSERT(ne00 % 2 == 0);
+ const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+ const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
+ const dim3 block_nums(nr, n_blocks_x, 1);
+
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
+
+ if (freq_factors == nullptr) {
+ rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
+ attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
+ } else {
+ rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
+ attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
+ }
+}
+
+template <bool forward, typename T, typename D>
+static void rope_neox_cuda(const T * x,
+ D * dst,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int s01,
+ const int s02,
+ const int s03,
+ const int s1,
+ const int s2,
+ const int s3,
+ const int n_dims,
+ const int nr,
+ const int32_t * pos,
+ const float freq_scale,
+ const float freq_base,
+ const float ext_factor,
+ const float attn_factor,
+ const rope_corr_dims corr_dims,
+ const float * freq_factors,
+ const int64_t * row_indices,
+ const int set_rows_stride,
+ cudaStream_t stream) {
+ GGML_ASSERT(ne00 % 2 == 0);
+ const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+ const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
+ const dim3 block_nums(nr, n_blocks_x, 1);
+
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
+
+ if (freq_factors == nullptr) {
+ rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
+ attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
+ } else {
+ rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
+ attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
+ }
+}
+
+template <bool forward, typename T>
+static void rope_multi_cuda(const T * x,
+ T * dst,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int s01,
+ const int s02,
+ const int s03,
+ const int s1,
+ const int s2,
+ const int s3,
+ const int n_dims,
+ const int nr,
+ const int32_t * pos,
+ const float freq_scale,
+ const float freq_base,
+ const float ext_factor,
+ const float attn_factor,
+ const rope_corr_dims corr_dims,
+ const float * freq_factors,
+ const mrope_sections sections,
+ const bool is_imrope,
+ cudaStream_t stream) {
+ GGML_ASSERT(ne00 % 2 == 0);
+ const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+ const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
+ const dim3 block_nums(nr, n_blocks_x, 1);
+
+ const float theta_scale = powf(freq_base, -2.0f / n_dims);
+
+ if (freq_factors == nullptr) {
+ rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
+ attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
+ } else {
+ rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
+ attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
+ }
+}
+
+template <bool forward, typename T>
+static void rope_vision_cuda(const T * x,
+ T * dst,
+ const int ne00,
+ const int ne01,
+ const int ne02,
+ const int s01,
+ const int s02,
+ const int s03,
+ const int s1,
+ const int s2,
+ const int s3,
+ const int n_dims,
+ const int nr,
+ const int32_t * pos,
+ const float freq_scale,
+ const float freq_base,
+ const float ext_factor,
+ const float attn_factor,
+ const rope_corr_dims corr_dims,
+ const float * freq_factors,
+ const mrope_sections sections,
+ cudaStream_t stream) {
+ GGML_ASSERT(ne00 % 2 == 0);
+ const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
+ const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
+ const dim3 block_nums(nr, n_blocks_x, 1);
+ // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
+ // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
+
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+ if (freq_factors == nullptr) {
+ rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
+ attn_factor, corr_dims, theta_scale, freq_factors, sections);
+ } else {
+ rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
+ x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
+ attn_factor, corr_dims, theta_scale, freq_factors, sections);
+ }
+}
+
+template <bool forward>
+void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
+ ggml_tensor * dst,
+ const ggml_tensor * set_rows = nullptr) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+
+ const float * src0_d = (const float *)src0->data;
+ const float * src1_d = (const float *)src1->data;
+
+ void * dst_d = dst->data;
+ const int64_t * row_indices = nullptr;
+ ggml_type dst_type = dst->type;
+ int set_rows_stride = 0;
+
+ if (set_rows != nullptr) {
+ GGML_ASSERT(forward);
+ dst_d = set_rows->data;
+ row_indices = (const int64_t *) set_rows->src[1]->data;
+ dst_type = set_rows->type;
+ set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
+ }
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+ // When not fused, src0 and dst types must match
+ // When fused (ROPE+VIEW+SET_ROWS), src0 may be F32 and dst may be F16
+ GGML_ASSERT(src0->type == dst->type || (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
+
+ const int64_t ne00 = src0->ne[0]; // head dims
+ const int64_t ne01 = src0->ne[1]; // num heads
+    const int64_t ne02 = src0->ne[2]; // num tokens
+ const int64_t nr = ggml_nrows(src0);
+
+ const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
+ const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
+ const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
+
+ const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
+ const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
+ const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
+
+ //const int n_past = ((int32_t *) dst->op_params)[0];
+ const int n_dims = ((int32_t *) dst->op_params)[1];
+ const int mode = ((int32_t *) dst->op_params)[2];
+ //const int n_ctx = ((int32_t *) dst->op_params)[3];
+ const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+ mrope_sections sections;
+
+ // RoPE alteration for extended context
+ float freq_base;
+ float freq_scale;
+ float ext_factor;
+ float attn_factor;
+ float beta_fast;
+ float beta_slow;
+
+ memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+ memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+ memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+ memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+ memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+ memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);
+
+ const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+ if (is_mrope) {
+ GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
+ }
+
+ if (is_vision) {
+ GGML_ASSERT(n_dims == ne00/2);
+ }
+
+ const int32_t * pos = (const int32_t *) src1_d;
+
+ const float * freq_factors = nullptr;
+ if (src2 != nullptr) {
+ freq_factors = (const float *) src2->data;
+ }
+
+ rope_corr_dims corr_dims;
+ ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
+
+ // compute
+ if (is_neox) {
+ if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
+ rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
+ s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+ set_rows_stride, stream);
+ } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
+ rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
+ s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+ set_rows_stride, stream);
+ } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
+ rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
+ s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+ set_rows_stride, stream);
+ } else {
+ GGML_ABORT("fatal error");
+ }
+ } else if (is_mrope && !is_vision) {
+ if (src0->type == GGML_TYPE_F32) {
+ rope_multi_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
+ s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
+ corr_dims, freq_factors, sections, is_imrope, stream);
+ } else if (src0->type == GGML_TYPE_F16) {
+ rope_multi_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
+ s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
+ corr_dims, freq_factors, sections, is_imrope, stream);
+ } else {
+ GGML_ABORT("fatal error");
+ }
+ } else if (is_vision) {
+ if (src0->type == GGML_TYPE_F32) {
+ rope_vision_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
+ s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
+ corr_dims, freq_factors, sections, stream);
+ } else if (src0->type == GGML_TYPE_F16) {
+ rope_vision_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
+ s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
+ corr_dims, freq_factors, sections, stream);
+ } else {
+ GGML_ABORT("fatal error");
+ }
+ } else {
+ if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
+ rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
+ s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+ set_rows_stride, stream);
+ } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
+ rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
+ s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+ set_rows_stride, stream);
+ } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
+ rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
+ s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
+ ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
+ set_rows_stride, stream);
+ } else {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
+void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_rope_impl<true>(ctx, dst);
+}
+
+void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_rope_impl<false>(ctx, dst);
+}
+
+void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) {
+ ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows);
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/rope.cuh b/llama.cpp/ggml/src/ggml-cuda/rope.cuh
new file mode 100644
index 0000000..72af086
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/rope.cuh
@@ -0,0 +1,9 @@
+#include "common.cuh"
+
+#define CUDA_ROPE_BLOCK_SIZE 256
+
+void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);
diff --git a/llama.cpp/ggml/src/ggml-cuda/scale.cu b/llama.cpp/ggml/src/ggml-cuda/scale.cu
new file mode 100644
index 0000000..0ddeff6
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/scale.cu
@@ -0,0 +1,34 @@
+#include "scale.cuh"
+
+#define MAX_GRIDDIM_X 0x7FFFFFFF
+
+static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) {
+ int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
+ int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x;
+
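+    // grid-stride loop: the grid is capped at MAX_GRIDDIM_X blocks, so each thread may handle several elements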
+ for (int64_t i = tid; i < nelements; i += stride) {
+ dst[i] = scale * x[i] + bias;
+ }
+}
+
+static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) {
+ const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
+ scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements);
+}
+
+void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ float scale;
+ float bias;
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+ memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
+
+ scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream);
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/scale.cuh b/llama.cpp/ggml/src/ggml-cuda/scale.cuh
new file mode 100644
index 0000000..8ff75c8
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/scale.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_SCALE_BLOCK_SIZE 256
+
+void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/set-rows.cu b/llama.cpp/ggml/src/ggml-cuda/set-rows.cu
new file mode 100644
index 0000000..631de7e
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/set-rows.cu
@@ -0,0 +1,330 @@
+#include "set-rows.cuh"
+#include "cpy-utils.cuh"
+
+typedef void (*set_rows_kernel_t)(const char * src, char * dst);
+
+// Generic quantized set_rows kernel template
+template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
+static __global__ void k_set_rows_quant(const float * __restrict__ src0,
+ const idx_t * __restrict__ src1,
+ block_type * __restrict__ dst,
+ const int64_t ne_total,
+ const int64_t ne10,
+ const int64_t ne11,
+ const int64_t ne12,
+ const int64_t ne13,
+ const int64_t s01,
+ const int64_t s02,
+ const int64_t s03,
+ const int64_t s10,
+ const int64_t s11,
+ const int64_t s12,
+ const int64_t s1,
+ const int64_t s2,
+ const int64_t s3,
+ const uint3 ne00,
+ const uint3 ne01,
+ const uint3 ne02,
+ const uint3 ne11_fd,
+ const uint3 ne12_fd) {
+ const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+
+ if (i >= ne_total) {
+ return;
+ }
+
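+    // each thread quantizes one block of qk consecutive source floats into one destination block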
+ const int64_t i_base = i * qk;
+ uint32_t tmp = (uint32_t) i_base;
+ uint2 div_mod;
+
+ div_mod = fast_div_modulo(tmp, ne00);
+ const int64_t i00 = div_mod.y;
+ tmp = div_mod.x;
+
+ div_mod = fast_div_modulo(tmp, ne01);
+ const int64_t i01 = div_mod.y;
+ tmp = div_mod.x;
+
+ div_mod = fast_div_modulo(tmp, ne02);
+ const int64_t i02 = div_mod.y;
+ const int64_t i03 = div_mod.x;
+
+ const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
+ const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
+ const int64_t i10 = i01;
+
+ const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
+
+ const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
+ block_type * dst_row_ptr = dst + (dst_row*s1 + i02*s2 + i03*s3) / sizeof(block_type);
+
+ const float * src_block = src0_row + i00;
+ block_type * dst_block = dst_row_ptr + i00 / qk;
+
+ quantize_func(src_block, dst_block);
+
+ GGML_UNUSED(ne10);
+ GGML_UNUSED(ne11);
+ GGML_UNUSED(ne12);
+ GGML_UNUSED(ne13);
+}
+
+// Template dispatch function for quantized set_rows
+template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
+static void set_rows_cuda_quant(
+ const float * src0_d, const idx_t * src1_d, block_type * dst_d,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+ const size_t nb01, const size_t nb02, const size_t nb03,
+ const size_t nb10, const size_t nb11, const size_t nb12,
+ const size_t nb1, const size_t nb2, const size_t nb3,
+ cudaStream_t stream) {
+
+ GGML_ASSERT(ne00 % qk == 0);
+ const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
+ const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
+ const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
+ const dim3 grid_size(num_blocks);
+
+ const int64_t s01 = nb01/sizeof(float);
+ const int64_t s02 = nb02/sizeof(float);
+ const int64_t s03 = nb03/sizeof(float);
+ const int64_t s10 = nb10/sizeof(idx_t);
+ const int64_t s11 = nb11/sizeof(idx_t);
+ const int64_t s12 = nb12/sizeof(idx_t);
+ const int64_t s1 = nb1;
+ const int64_t s2 = nb2;
+ const int64_t s3 = nb3;
+
+ if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
+ const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
+ const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
+ const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
+ const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
+ const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
+
+ k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
+ src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,
+ ne01_fd, ne02_fd, ne11_fd, ne12_fd);
+ }
+}
+
+template <typename src_t, typename idx_t, typename dst_t>
+static __global__ void k_set_rows(const src_t * __restrict__ src0,
+ const idx_t * __restrict__ src1,
+ dst_t * __restrict__ dst,
+ const int64_t ne_total,
+ const int64_t ne10,
+ const int64_t ne11,
+ const int64_t ne12,
+ const int64_t ne13,
+ const int64_t s01,
+ const int64_t s02,
+ const int64_t s03,
+ const int64_t s10,
+ const int64_t s11,
+ const int64_t s12,
+ const int64_t s1,
+ const int64_t s2,
+ const int64_t s3,
+ const uint3 ne00,
+ const uint3 ne01,
+ const uint3 ne02,
+ const uint3 ne11_fd,
+ const uint3 ne12_fd) {
+ const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
+
+ if (i >= ne_total) {
+ return;
+ }
+
+ uint32_t tmp = (uint32_t) i;
+ uint2 div_mod;
+
+ div_mod = fast_div_modulo(tmp, ne00);
+ const int64_t i00 = div_mod.y;
+ tmp = div_mod.x;
+
+ div_mod = fast_div_modulo(tmp, ne01);
+ const int64_t i01 = div_mod.y;
+ tmp = div_mod.x;
+
+ div_mod = fast_div_modulo(tmp, ne02);
+ const int64_t i02 = div_mod.y;
+ const int64_t i03 = div_mod.x;
+
+ const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
+ const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
+ const int64_t i10 = i01;
+
+ const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
+
+ const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03;
+ dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3;
+
+ dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
+
+ GGML_UNUSED(ne10);
+ GGML_UNUSED(ne11);
+ GGML_UNUSED(ne12);
+ GGML_UNUSED(ne13);
+}
+
+template<typename src_t, typename idx_t, typename dst_t>
+static void set_rows_cuda(
+ const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
+ const size_t nb01, const size_t nb02, const size_t nb03,
+ const size_t nb10, const size_t nb11, const size_t nb12,
+ const size_t nb1, const size_t nb2, const size_t nb3,
+ cudaStream_t stream) {
+
+ const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
+ const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE;
+ const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE);
+ const dim3 grid_size(num_blocks);
+
+
+ const int64_t s01 = nb01/sizeof(src_t);
+ const int64_t s02 = nb02/sizeof(src_t);
+ const int64_t s03 = nb03/sizeof(src_t);
+ const int64_t s10 = nb10/sizeof(idx_t);
+ const int64_t s11 = nb11/sizeof(idx_t);
+ const int64_t s12 = nb12/sizeof(idx_t);
+ const int64_t s1 = nb1/sizeof(dst_t);
+ const int64_t s2 = nb2/sizeof(dst_t);
+ const int64_t s3 = nb3/sizeof(dst_t);
+
+ if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
+ const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
+ const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
+ const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
+ const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
+ const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
+
+ k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
+ s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
+ ne11_fd, ne12_fd);
+ }
+}
+
+template<typename src_t, typename idx_t>
+static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ const src_t * src0_d = (const src_t *)src0->data;
+ const idx_t * src1_d = (const idx_t *)src1->data;
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ cudaStream_t stream = ctx.stream();
+
+
+ if (dst->type == GGML_TYPE_F32) {
+ set_rows_cuda(
+ src0_d, src1_d, (float*)dst->data,
+ ne00, ne01, ne02, ne03,
+ ne10, ne11, ne12, ne13,
+ nb01, nb02, nb03,
+ nb10, nb11, nb12,
+ nb1, nb2, nb3,
+ stream
+ );
+ } else if (dst->type == GGML_TYPE_F16) {
+ set_rows_cuda(
+ src0_d, src1_d, (half*)dst->data,
+ ne00, ne01, ne02, ne03,
+ ne10, ne11, ne12, ne13,
+ nb01, nb02, nb03,
+ nb10, nb11, nb12,
+ nb1, nb2, nb3,
+ stream
+ );
+ } else if (dst->type == GGML_TYPE_BF16) {
+ set_rows_cuda(
+ src0_d, src1_d, (nv_bfloat16*)dst->data,
+ ne00, ne01, ne02, ne03,
+ ne10, ne11, ne12, ne13,
+ nb01, nb02, nb03,
+ nb10, nb11, nb12,
+ nb1, nb2, nb3,
+ stream
+ );
+ } else if (dst->type == GGML_TYPE_Q4_0) {
+ set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
+ src0_d, src1_d, (block_q4_0*)dst->data,
+ ne00, ne01, ne02, ne03,
+ ne10, ne11, ne12, ne13,
+ nb01, nb02, nb03,
+ nb10, nb11, nb12,
+ nb1, nb2, nb3,
+ stream
+ );
+ } else if (dst->type == GGML_TYPE_Q4_1) {
+ set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
+ src0_d, src1_d, (block_q4_1*)dst->data,
+ ne00, ne01, ne02, ne03,
+ ne10, ne11, ne12, ne13,
+ nb01, nb02, nb03,
+ nb10, nb11, nb12,
+ nb1, nb2, nb3,
+ stream
+ );
+ } else if (dst->type == GGML_TYPE_Q5_0) {
+ set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
+ src0_d, src1_d, (block_q5_0*)dst->data,
+ ne00, ne01, ne02, ne03,
+ ne10, ne11, ne12, ne13,
+ nb01, nb02, nb03,
+ nb10, nb11, nb12,
+ nb1, nb2, nb3,
+ stream
+ );
+ } else if (dst->type == GGML_TYPE_Q5_1) {
+ set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
+ src0_d, src1_d, (block_q5_1*)dst->data,
+ ne00, ne01, ne02, ne03,
+ ne10, ne11, ne12, ne13,
+ nb01, nb02, nb03,
+ nb10, nb11, nb12,
+ nb1, nb2, nb3,
+ stream
+ );
+ } else if (dst->type == GGML_TYPE_Q8_0) {
+ set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
+ src0_d, src1_d, (block_q8_0*)dst->data,
+ ne00, ne01, ne02, ne03,
+ ne10, ne11, ne12, ne13,
+ nb01, nb02, nb03,
+ nb10, nb11, nb12,
+ nb1, nb2, nb3,
+ stream
+ );
+ } else if (dst->type == GGML_TYPE_IQ4_NL) {
+ set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
+ src0_d, src1_d, (block_iq4_nl*)dst->data,
+ ne00, ne01, ne02, ne03,
+ ne10, ne11, ne12, ne13,
+ nb01, nb02, nb03,
+ nb10, nb11, nb12,
+ nb1, nb2, nb3,
+ stream
+ );
+ } else {
+ GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
+ }
+}
+
+
+void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);
+
+ if (src1->type == GGML_TYPE_I64) {
+ set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);
+ } else {
+ set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh b/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh
new file mode 100644
index 0000000..c140c08
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/set-rows.cuh
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "common.cuh"
+
+#define CUDA_SET_ROWS_BLOCK_SIZE 256
+
+void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/set.cu b/llama.cpp/ggml/src/ggml-cuda/set.cu
new file mode 100644
index 0000000..04bfe07
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/set.cu
@@ -0,0 +1,39 @@
+#include "set.cuh"
+#include "cpy.cuh"
+
+void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32));
+ GGML_ASSERT(src1->type == src0->type);
+ GGML_ASSERT(dst ->type == src0->type);
+
+ GGML_ASSERT(ggml_is_contiguous(dst));
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+
+ const size_t nb1 = ((int32_t *) dst->op_params)[0];
+ const size_t nb2 = ((int32_t *) dst->op_params)[1];
+ const size_t nb3 = ((int32_t *) dst->op_params)[2];
+ const size_t offset = ((int32_t *) dst->op_params)[3];
+    const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+ if (!inplace) {
+ ggml_cuda_cpy(ctx, src0, dst);
+ }
+
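+    // build a temporary view of dst with the offset/strides from op_params and copy src1 into that view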
+ ggml_tensor dst_view = *dst;
+ dst_view.data = (void *)((char *)dst->data + offset);
+ dst_view.ne[0] = src1->ne[0];
+ dst_view.ne[1] = src1->ne[1];
+ dst_view.ne[2] = src1->ne[2];
+ dst_view.ne[3] = src1->ne[3];
+
+ dst_view.nb[0] = ggml_element_size(dst);
+ dst_view.nb[1] = nb1;
+ dst_view.nb[2] = nb2;
+ dst_view.nb[3] = nb3;
+
+ ggml_cuda_cpy(ctx, src1, &dst_view);
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/set.cuh b/llama.cpp/ggml/src/ggml-cuda/set.cuh
new file mode 100644
index 0000000..dd09529
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/set.cuh
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "common.cuh"
+
+#define CUDA_SET_BLOCK_SIZE 256
+
+void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/softcap.cu b/llama.cpp/ggml/src/ggml-cuda/softcap.cu
new file mode 100644
index 0000000..40dfe45
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/softcap.cu
@@ -0,0 +1,34 @@
+#include "softcap.cuh"
+
+static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ dst[i] = tanhf(scale * x[i]) * softcap;
+}
+
+static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE;
+ softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, scale, softcap, k);
+}
+
+// fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src) {
+ const ggml_tensor * src0 = src->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
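+    // the scale factor comes from the leading GGML_OP_SCALE node (src), the cap factor from the trailing one (dst)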
+ float scale;
+ float softcap;
+ memcpy(&scale, (float *) src->op_params + 0, sizeof(float));
+ memcpy(&softcap, (float *) dst->op_params + 0, sizeof(float));
+
+ softcap_f32_cuda(src0_d, dst_d, scale, softcap, ggml_nelements(src0), stream);
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/softcap.cuh b/llama.cpp/ggml/src/ggml-cuda/softcap.cuh
new file mode 100644
index 0000000..6d34fb2
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/softcap.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_SOFTCAP_BLOCK_SIZE 256
+
+void ggml_cuda_op_softcap(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * src);
diff --git a/llama.cpp/ggml/src/ggml-cuda/softmax.cu b/llama.cpp/ggml/src/ggml-cuda/softmax.cu
new file mode 100644
index 0000000..dc06d06
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/softmax.cu
@@ -0,0 +1,472 @@
+#include "common.cuh"
+#include "ggml.h"
+#include "softmax.cuh"
+
+#ifdef GGML_USE_HIP
+#include <hip/hip_cooperative_groups.h>
+#else
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#endif // GGML_USE_HIP
+
+#include <cstdint>
+#include <utility>
+
+template <typename T>
+static __device__ __forceinline__ float t2f32(T val) {
+ return (float) val;
+}
+
+template <>
+__device__ float __forceinline__ t2f32<half>(half val) {
+ return __half2float(val);
+}
+
+struct soft_max_params {
+
+ int64_t nheads;
+ uint32_t n_head_log2;
+ int64_t ncols;
+ int64_t nrows_x;
+ int64_t nrows_y;
+ int64_t ne00;
+ int64_t ne01;
+ int64_t ne02;
+ int64_t ne03;
+ int64_t nb11;
+ int64_t nb12;
+ int64_t nb13;
+
+ int64_t ne12;
+ int64_t ne13;
+ float scale;
+ float max_bias;
+ float m0;
+ float m1;
+};
+
+// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
+// As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template <bool use_shared, int ncols_template, int block_size_template, typename T>
+static __global__ void soft_max_f32(
+ const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
+ const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
+
+ const int tid = threadIdx.x;
+
+ const int64_t i03 = blockIdx.z;
+ const int64_t i02 = blockIdx.y;
+ const int64_t i01 = blockIdx.x;
+
+    // TODO: non-contiguous inputs/outputs
+ const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
+
+ const int64_t i11 = i01;
+ const int64_t i12 = i02 % p.ne12;
+ const int64_t i13 = i03 % p.ne13;
+
+ x += int64_t(rowx)*ncols;
+ mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
+ dst += int64_t(rowx)*ncols;
+
+ const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
+
+ const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
+
+ extern __shared__ float data_soft_max_f32[];
+ float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
+ // shared memory buffer to cache values between iterations:
+ float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
+
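+    // attention sinks: the per-head sink logit takes part in the max and in the normalizer but is never written to dst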
+ float max_val = sinks ? sinks[i02] : -INFINITY;
+
+#pragma unroll
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
+ const int col = col0 + tid;
+
+ if (ncols_template == 0 && col >= ncols) {
+ break;
+ }
+
+ const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
+
+ vals[col] = val;
+ max_val = max(max_val, val);
+ }
+
+ // find the max value in the block
+ max_val = block_reduce<block_reduce_method::MAX, block_size_template>(max_val, buf_iw);
+
+ float tmp = 0.0f; // partial sum
+
+#pragma unroll
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
+ const int col = col0 + tid;
+
+ if (ncols_template == 0 && col >= ncols) {
+ break;
+ }
+
+ const float val = expf(vals[col] - max_val);
+ tmp += val;
+ vals[col] = val;
+ }
+
+ // find the sum of exps in the block
+ tmp = block_reduce<block_reduce_method::SUM, block_size_template>(tmp, buf_iw);
+
+ if (sinks) {
+ tmp += expf(sinks[i02] - max_val);
+ }
+
+ const float inv_sum = 1.0f / tmp;
+
+#pragma unroll
+ for (int col0 = 0; col0 < ncols; col0 += block_size) {
+ const int col = col0 + tid;
+
+ if (ncols_template == 0 && col >= ncols) {
+ return;
+ }
+
+ dst[col] = vals[col] * inv_sum;
+ }
+}
+
+// TODO: Template to allow keeping ncols in registers if they fit
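+// Single-row softmax parallelized across the whole grid in three passes:
+// (1) global max, (2) sum of exponentials, (3) normalization; CTA-level partial results
+// are exchanged through tmp_maxs/tmp_sums in global memory with grid-wide syncs in between.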
+static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
+ float * __restrict__ dst,
+ float * __restrict__ tmp_maxs,
+ float * __restrict__ tmp_sums,
+ const soft_max_params p) {
+ namespace cg = cooperative_groups;
+
+ const cg::grid_group g = cg::this_grid();
+
+ const int tid = threadIdx.x;
+ const int col_start = blockIdx.x * blockDim.x + tid;
+ const int n_elem_per_thread = 4;
+
+ float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
+ float local_max = -INFINITY;
+ const int step_size = gridDim.x * blockDim.x;
+ __shared__ float shared_vals[32];
+
+ // Compute thread-local max
+ for (int col = col_start; col < p.ncols;) {
+#pragma unroll
+ for (int i = 0; i < n_elem_per_thread; i++) {
+ const int idx = col + i * step_size;
+ local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
+ }
+#pragma unroll
+ for (int i = 0; i < n_elem_per_thread; i++) {
+ local_max = fmaxf(local_max, local_vals[i]);
+ }
+ col += step_size * n_elem_per_thread;
+ }
+
+ // Compute CTA-level max
+ local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);
+
+ // Store CTA-level max to GMEM
+ if (tid == 0) {
+ tmp_maxs[blockIdx.x] = local_max;
+ }
+ g.sync();
+
+    // Compute global max from CTA-level maxes
+ assert(gridDim.x < blockDim.x); // currently we only support this case
+ if (tid < gridDim.x) {
+ local_max = tmp_maxs[tid];
+ } else {
+ local_max = -INFINITY;
+ }
+ local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals);
+
+ // Compute softmax dividends, accumulate divisor
+ float tmp_expf = 0.0f;
+ for (int col = col_start; col < p.ncols;) {
+#pragma unroll
+ for (int i = 0; i < n_elem_per_thread; i++) {
+ const int idx = col + i * step_size;
+ local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
+ }
+#pragma unroll
+ for (int i = 0; i < n_elem_per_thread; i++) {
+ const int idx = col + i * step_size;
+ if (idx < p.ncols) {
+ const float tmp = expf(local_vals[i] - local_max);
+ tmp_expf += tmp;
+ dst[idx] = tmp;
+ }
+ }
+ col += step_size * n_elem_per_thread;
+ }
+
+ // Reduce divisor within CTA
+ tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);
+
+ // Store CTA-level sum to GMEM
+ if (tid == 0) {
+ tmp_sums[blockIdx.x] = tmp_expf;
+ }
+ g.sync();
+
+ // Compute global sum from CTA-level sums
+ if (tid < gridDim.x) {
+ tmp_expf = tmp_sums[tid];
+ } else {
+ tmp_expf = 0.0f;
+ }
+ tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals);
+
+ // Divide dividend by global sum + store data
+ for (int col = col_start; col < p.ncols;) {
+#pragma unroll
+ for (int i = 0; i < n_elem_per_thread; i++) {
+ const int idx = col + i * step_size;
+ local_vals[i] = idx < p.ncols ? dst[idx] : -INFINITY;
+ }
+#pragma unroll
+ for (int i = 0; i < n_elem_per_thread; i++) {
+ const int idx = col + i * step_size;
+ if (idx < p.ncols) {
+ dst[idx] = local_vals[i] / tmp_expf;
+ }
+ }
+ col += step_size * n_elem_per_thread;
+ }
+}
+
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
+
+static __global__ void soft_max_back_f32(
+ const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {
+ const int tid = threadIdx.x;
+ const int rowx = blockIdx.x;
+
+ grad += int64_t(rowx)*ncols;
+ dstf += int64_t(rowx)*ncols;
+ dst += int64_t(rowx)*ncols;
+
+ float dgf_dot = 0.0f; // dot product of dst from forward pass and gradients
+
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
+ dgf_dot += dstf[col]*grad[col];
+ }
+
+ dgf_dot = warp_reduce_sum(dgf_dot);
+
+ for (int col = tid; col < ncols; col += WARP_SIZE) {
+ dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
+ }
+}
+
+template<int... Ns, typename T>
+static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
+ const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
+{
+ const int id = ggml_cuda_get_device();
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+ auto launch_kernel = [=](auto I) -> bool {
+ constexpr int ncols = decltype(I)::value;
+ constexpr int block = (ncols > 1024 ? 1024 : ncols);
+
+ if (p.ncols == ncols) {
+ CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
+ soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
+ (x, mask, sinks, dst, p);
+ return true;
+ }
+ return false;
+ };
+
+ // unary fold over launch_kernel
+ if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
+ return;
+ }
+
+ //default case
+ CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
+ soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
+}
+
+__launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,
+ float * __restrict__ dst,
+ float * __restrict__ tmp_maxs,
+ float * __restrict__ tmp_sums,
+ const soft_max_params p)
+// We loop over all rows instead of parallelizing across gridDim.y as cooperative groups
+// currently only support synchronizing the complete grid if not launched as a cluster group
+// (which requires CC > 9.0)
+// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#grid-synchronization
+// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group
+{
+ for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) {
+ soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_maxs,
+ tmp_sums, p);
+ }
+}
+
+template <typename T>
+static void soft_max_f32_cuda(const float * x,
+ const T * mask,
+ const float * sinks,
+ float * dst,
+ const soft_max_params & params,
+ cudaStream_t stream,
+ [[maybe_unused]] ggml_backend_cuda_context & ctx) {
+ int nth = WARP_SIZE;
+ const int64_t ncols_x = params.ncols;
+
+ while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+ const dim3 block_dims(nth, 1, 1);
+ const dim3 block_nums(params.ne01, params.ne02, params.ne03);
+ const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
+ static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+
+
+ const int id = ggml_cuda_get_device();
+ const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+
+
+ if (nbytes_shared <= smpbo) {
+ launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
+ } else {
+ // Parallelize across SMs for top-p/dist-sampling
+            // The heuristic for choosing between parallelizing rows across SMs and parallelizing a single row while
+            // looping over all rows was tuned on a B6000 GPU and can be adapted further for lower-SM-count GPUs,
+            // though keeping data in registers should be implemented first as that is the optimal solution.
+ if (ggml_cuda_info().devices[id].supports_cooperative_launch &&
+ ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr &&
+ params.scale == 1.0f && params.max_bias == 0.0f) {
+ ggml_cuda_pool_alloc<float> tmp_maxs_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
+ ggml_cuda_pool_alloc<float> tmp_sums_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
+
+ void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs_alloc.ptr,
+ (void *) &tmp_sums_alloc.ptr, (void *) const_cast<soft_max_params *>(&params) };
+ CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
+ dim3(ggml_cuda_info().devices[id].nsm, 1, 1),
+ dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));
+ } else {
+ const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
+ soft_max_f32<false, 0, 0>
+ <<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
+ }
+ }
+}
+
+static void soft_max_back_f32_cuda(
+ const float * grad, const float * dstf, float * dst,
+ const int ncols, const int nrows, const float scale, cudaStream_t stream) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ const dim3 block_nums(nrows, 1, 1);
+
+ soft_max_back_f32<<<block_nums, block_dims, 0, stream>>>(grad, dstf, dst, ncols, scale);
+}
+
+void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+
+ const float * src0_d = (const float *) src0->data;
+ const void * src1_d = src1 ? (const void *) src1->data : nullptr;
+ const void * src2_d = src2 ? (const void *) src2->data : nullptr;
+ float * dst_d = (float *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
+ const int64_t nrows_x = ggml_nrows(src0);
+ const int64_t nrows_y = src0->ne[1];
+
+ const int64_t ne00 = src0->ne[0];
+
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+ const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+
+ const int64_t nb11 = src1 ? src1->nb[1] : 1;
+ const int64_t nb12 = src1 ? src1->nb[2] : 1;
+ const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+ const int64_t ne12 = src1 ? src1->ne[2] : 1;
+ const int64_t ne13 = src1 ? src1->ne[3] : 1;
+
+ const uint32_t n_head = src0->ne[2];
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+
+ soft_max_params params = {};
+ params.nheads = src0->ne[2];
+ params.n_head_log2 = n_head_log2;
+ params.ncols = ne00;
+ params.nrows_x = nrows_x;
+ params.nrows_y = nrows_y;
+ params.ne00 = src0->ne[0];
+ params.ne01 = src0->ne[1];
+ params.ne02 = src0->ne[2];
+ params.ne03 = src0->ne[3];
+ params.nb11 = nb11;
+ params.nb12 = nb12;
+ params.nb13 = nb13;
+ params.ne12 = ne12;
+ params.ne13 = ne13;
+ params.scale = scale;
+ params.max_bias = max_bias;
+ params.m0 = m0;
+ params.m1 = m1;
+
+ if (use_f16) {
+ soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
+ } else {
+ soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx);
+ }
+}
+
+void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0]; // grad
+ const ggml_tensor * src1 = dst->src[1]; // forward pass output
+
+ const float * src0_d = (const float *) src0->data;
+ const float * src1_d = (const float *) src1->data;
+ float * dst_d = (float *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int64_t ncols = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
+
+ GGML_ASSERT(max_bias == 0.0f);
+
+ soft_max_back_f32_cuda(src0_d, src1_d, dst_d, ncols, nrows, scale, stream);
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/softmax.cuh b/llama.cpp/ggml/src/ggml-cuda/softmax.cuh
new file mode 100644
index 0000000..93dfee8
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/softmax.cuh
@@ -0,0 +1,7 @@
+#include "common.cuh"
+
+#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
+
+void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_soft_max_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/solve_tri.cu b/llama.cpp/ggml/src/ggml-cuda/solve_tri.cu
new file mode 100644
index 0000000..177ffc2
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/solve_tri.cu
@@ -0,0 +1,275 @@
+#include "common.cuh"
+#include "ggml.h"
+#include "solve_tri.cuh"
+
+#define MAX_N_FAST 64
+#define MAX_K_FAST 32
+
+static __global__ void get_batch_pointers(const float * A,
+ float * X,
+ const float ** A_ptrs,
+ float ** X_ptrs,
+ int64_t ne02,
+ int64_t total_batches,
+ size_t s02,
+ size_t s03,
+ size_t s2,
+ size_t s3) {
+ const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (idx >= total_batches) {
+ return;
+ }
+
+ const int64_t i3 = idx / ne02;
+ const int64_t i2 = idx % ne02;
+
+ A_ptrs[idx] = A + i3 * s03 + i2 * s02;
+ X_ptrs[idx] = X + i3 * s3 + i2 * s2;
+}
+
+static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx,
+ const float * A,
+ const float * B,
+ float * X,
+ int n,
+ int k,
+ int64_t ne02,
+ int64_t ne03,
+ size_t s02,
+ size_t s03,
+ size_t s12,
+ size_t s13,
+ size_t s2,
+ size_t s3,
+ cudaStream_t stream) {
+ const float alpha = 1.0f;
+ const int64_t total_batches = ne02 * ne03;
+ if (total_batches == 0) {
+ return;
+ }
+
+ // Bulk copy B -> X (contiguous tensors)
+ if (X != B) {
+ const int64_t total_elements_BX = n * k * total_batches;
+ CUDA_CHECK(cudaMemcpyAsync(X, B, total_elements_BX * sizeof(float), cudaMemcpyDeviceToDevice, stream));
+ }
+
+ const int id = ggml_cuda_get_device();
+
+ ggml_cuda_pool_alloc<const float *> A_ptrs_alloc(ctx.pool(id), total_batches);
+ ggml_cuda_pool_alloc<float *> X_ptrs_alloc(ctx.pool(id), total_batches);
+
+ const float ** A_ptrs_dev = A_ptrs_alloc.get();
+ float ** X_ptrs_dev = X_ptrs_alloc.get();
+
+ get_batch_pointers<<<(total_batches + 255) / 256, 256, 0, stream>>>(A, X, A_ptrs_dev, X_ptrs_dev, ne02,
+ total_batches, s02, s03, s2, s3);
+
+ CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+
+    // Yes, this is necessary; without it we get RMSE errors
+ CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_DEFAULT_MATH));
+ CUBLAS_CHECK(cublasStrsmBatched(ctx.cublas_handle(id), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_UPPER, CUBLAS_OP_N,
+ CUBLAS_DIAG_NON_UNIT, k, n, &alpha, A_ptrs_dev, n, X_ptrs_dev, k, total_batches));
+
+ // revert to standard mode from common.cuh
+ CUBLAS_CHECK(cublasSetMathMode(ctx.cublas_handle(id), CUBLAS_TF32_TENSOR_OP_MATH));
+
+ GGML_UNUSED_VARS(s12, s13);
+}
+
+// ======================
+// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
+// ======================
+// When n_template/k_template == 0 the bounds for the loops in this function are not
+// known and can't be unrolled. As we want to keep pragma unroll for all other
+// cases we suppress the clang transformation warning here.
+#ifdef __clang__
+# pragma clang diagnostic push
+# pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template <int n_template, int k_template>
+static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
+ const float * __restrict__ B,
+ float * __restrict__ X,
+ const uint3 ne02,
+ const size_t nb02,
+ const size_t nb03,
+ const size_t nb12,
+ const size_t nb13,
+ const size_t nb2,
+ const size_t nb3,
+ const int n_arg,
+ const int k_arg) {
+ const int n = n_template == 0 ? n_arg : n_template;
+ const int k = k_template == 0 ? k_arg : k_template;
+
+ const int batch_idx = blockIdx.x;
+ const int lane = threadIdx.x;
+ const int col_idx = threadIdx.y;
+
+ if (col_idx >= k) {
+ return;
+ }
+
+ const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
+ const int64_t i02 = i02_i03.y;
+ const int64_t i03 = i02_i03.x;
+
+ const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
+ const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
+ float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
+
+ __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
+
+ const int offset = threadIdx.x + threadIdx.y * blockDim.x;
+
+#pragma unroll
+ for (int i = 0; i < n * n; i += k * WARP_SIZE) {
+ const int i0 = i + offset;
+ if (i0 < n * n) {
+ sA[i0] = A_batch[i0];
+ }
+ }
+
+ __syncthreads();
+
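+    // forward substitution with one warp per RHS column: lane l holds solution rows l (x_low)
+    // and l + WARP_SIZE (x_high); each row's dot product with already-solved rows is formed via warp reductions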
+ float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
+ float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
+
+ const int half = WARP_SIZE;
+ const int nrows_low = (n < half) ? n : half;
+
+#pragma unroll
+ for (int row = 0; row < nrows_low; ++row) {
+ float sum = 0.0f;
+ if (lane < row) {
+ sum += sA[row * n + lane] * x_low;
+ }
+ sum = warp_reduce_sum(sum);
+
+ if (lane == row) {
+ x_low = (x_low - sum) / sA[row * n + row];
+ }
+ }
+
+#pragma unroll
+ for (int row = half; row < n; ++row) {
+ float sum = sA[row * n + lane] * x_low;
+ const int j = half + lane;
+ if (j < row) {
+ sum += sA[row * n + j] * x_high;
+ }
+ sum = warp_reduce_sum(sum);
+
+ if (lane == row - half) {
+ x_high = (x_high - sum) / sA[row * n + row];
+ }
+ }
+
+#pragma unroll
+ for (int rr = 0; rr < 2; ++rr) {
+ const int row = rr * WARP_SIZE + lane;
+ if (row < n) {
+ const float val = (row < half) ? x_low : x_high;
+ X_batch[row * k + col_idx] = val;
+ }
+ }
+}
+#ifdef __clang__
+# pragma clang diagnostic pop
+#endif // __clang__
+
+static void solve_tri_f32_cuda(const float * A,
+ const float * B,
+ float * X,
+ int n,
+ int k,
+ int64_t ne02,
+ int64_t ne03,
+ size_t nb02,
+ size_t nb03,
+ size_t nb12,
+ size_t nb13,
+ size_t nb2,
+ size_t nb3,
+ cudaStream_t stream) {
+ const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
+ dim3 threads(WARP_SIZE, k);
+ dim3 grid(ne02 * ne03);
+ if (n == 64) {
+ switch (k) {
+ case 32:
+ solve_tri_f32_fast<64, 32>
+ <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+ break;
+ case 16:
+ solve_tri_f32_fast<64, 16>
+ <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+ break;
+ case 14:
+ solve_tri_f32_fast<64, 14>
+ <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+ break;
+ case 12:
+ solve_tri_f32_fast<64, 12>
+ <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+ break;
+ case 10:
+ solve_tri_f32_fast<64, 10>
+ <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+ break;
+ case 8:
+ solve_tri_f32_fast<64, 8>
+ <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+ break;
+ case 6:
+ solve_tri_f32_fast<64, 6>
+ <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+ break;
+ case 4:
+ solve_tri_f32_fast<64, 4>
+ <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+ break;
+ case 2:
+ solve_tri_f32_fast<64, 2>
+ <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+ break;
+ case 1:
+ solve_tri_f32_fast<64, 1>
+ <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
+ break;
+ default:
+ solve_tri_f32_fast<0, 0>
+ <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+ }
+ } else { // run general case
+ solve_tri_f32_fast<0, 0>
+ <<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
+ }
+}
+
+void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0]; // A (n×n, lower triangular)
+ const ggml_tensor * src1 = dst->src[1]; // B (n×k)
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+
+ const int64_t n = src0->ne[0];
+ const int64_t k = src1->ne[0];
+ const int64_t ne02 = src0->ne[2];
+ const int64_t ne03 = src0->ne[3];
+
+ if (n <= MAX_N_FAST && k <= MAX_K_FAST) {
+ solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
+ src0->ne[2], src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
+ src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
+ dst->nb[3] / sizeof(float), ctx.stream());
+ } else {
+ solve_tri_f32_cublas(ctx, (const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k,
+ ne02, ne03, src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
+ src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
+ dst->nb[3] / sizeof(float), ctx.stream());
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/solve_tri.cuh b/llama.cpp/ggml/src/ggml-cuda/solve_tri.cuh
new file mode 100644
index 0000000..6399923
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/solve_tri.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu b/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu
new file mode 100644
index 0000000..6d5ea70
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cu
@@ -0,0 +1,150 @@
+#include "ssm-conv.cuh"
+
+template <size_t split_d_inner, size_t d_conv>
+static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
+ const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
+ float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
+ const int64_t n_t) {
+ GGML_UNUSED(src0_nb0);
+ const int tid = threadIdx.x;
+ const int bidx = blockIdx.x;
+ const int bidy = blockIdx.y;
+
+ const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
+ const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
+ float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);
+
+ const int stride_x = src0_nb1 / sizeof(float);
+ const int stride_w = src1_nb1 / sizeof(float);
+ const int stride_y = dst_nb1 / sizeof(float);
+
+ float x[d_conv] = { 0.0f };
+ float w[d_conv] = { 0.0f };
+
+#pragma unroll
+ for (size_t j = 0; j < d_conv; j++) {
+ w[j] = w_block[tid * stride_w + j];
+ }
+
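+    // x[] is a circular window over the last d_conv inputs; after the first token only one new element is loaded per step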
+ for (int64_t i = 0; i < n_t; i++) {
+ float sumf = 0.0f;
+
+ if (i == 0) {
+ for (size_t j = 0; j < d_conv; j++) {
+ x[j] = x_block[tid * stride_x + j];
+ }
+ } else {
+ x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
+ }
+
+#pragma unroll
+ for (size_t j = 0; j < d_conv; j++) {
+ sumf += x[(i + j) % d_conv] * w[j];
+ }
+ y_block[i * stride_y + tid] = sumf;
+ }
+}
+
+template <size_t split_d_inner, size_t d_conv, int64_t split_n_t>
+static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
+ const int src0_nb0, const int src0_nb1, const int src0_nb2,
+ const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
+ const int dst_nb1, const int dst_nb2, const int64_t n_t) {
+ const int tid = threadIdx.x;
+ const int bidx = blockIdx.x;
+ const int bidy = blockIdx.y;
+ const int bidz = blockIdx.z;
+
+ const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
+ bidz * split_n_t * src0_nb0);
+ const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
+ float * y_block =
+ (float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);
+
+ const int stride_x = src0_nb1 / sizeof(float);
+ const int stride_w = src1_nb1 / sizeof(float);
+ const int stride_y = dst_nb1 / sizeof(float);
+
+ float x[d_conv] = { 0.0f };
+ float w[d_conv] = { 0.0f };
+
+#pragma unroll
+ for (size_t j = 0; j < d_conv; j++) {
+ w[j] = w_block[tid * stride_w + j];
+ }
+
+#pragma unroll
+ for (int64_t i = 0; i < split_n_t; i++) {
+ if (bidz * split_n_t + i < n_t) {
+ float sumf = 0.0f;
+
+ if (i == 0) {
+ for (size_t j = 0; j < d_conv; j++) {
+ x[j] = x_block[tid * stride_x + j];
+ }
+ } else {
+ x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
+ }
+
+#pragma unroll
+ for (size_t j = 0; j < d_conv; j++) {
+ sumf += x[(i + j) % d_conv] * w[j];
+ }
+ y_block[i * stride_y + tid] = sumf;
+ }
+ }
+}
+
+static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
+ const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
+ const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
+ const int64_t n_s, cudaStream_t stream) {
+ const int threads = 128;
+ GGML_ASSERT(nr % threads == 0);
+
+ auto launch_kernel = [&](auto NC) {
+ constexpr int kNC = decltype(NC)::value;
+ if (n_t <= 32) {
+ const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
+ ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
+ dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+ } else {
+ const int64_t split_n_t = 32;
+ dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
+ ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>(
+ src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
+ }
+ };
+
+ switch (nc) {
+ case 3: launch_kernel(std::integral_constant<int, 3>{}); break;
+ case 4: launch_kernel(std::integral_constant<int, 4>{}); break;
+ case 9: launch_kernel(std::integral_constant<int, 9>{}); break;
+        default: GGML_ABORT("Only kernel sizes 3, 4 and 9 are supported right now.");
+ }
+}
+
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const struct ggml_tensor * src0 = dst->src[0]; // conv_x
+ const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
+
+ const int64_t nc = src1->ne[0]; // d_conv
+ const int64_t nr = src0->ne[1]; // d_inner
+ const int64_t n_t = dst->ne[1]; // tokens per sequence
+ const int64_t n_s = dst->ne[2]; // number of sequences in the batch
+
+ GGML_ASSERT(dst->ne[0] == nr);
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
+ GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
+
+ const float * src0_d = (const float *) src0->data;
+ const float * src1_d = (const float *) src1->data;
+ float * dst_d = (float *) dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
+ dst->nb[2], nc, nr, n_t, n_s, stream);
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cuh b/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cuh
new file mode 100644
index 0000000..8e6c1f0
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/ssm-conv.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu b/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu
new file mode 100644
index 0000000..c1d4e2b
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cu
@@ -0,0 +1,342 @@
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+#define USE_CUB
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
+
+#ifdef USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // USE_CUB
+
+#include "ssm-scan.cuh"
+
+// We would like to keep pragma unroll for cases where L_template is not 0,
+// so we suppress the clang transformation warning.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+template <size_t splitD, size_t N, size_t L_template>
+__global__ void __launch_bounds__(splitD, 1)
+ ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2,
+ const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5,
+ const int32_t * __restrict__ src6, float * __restrict__ dst,
+ const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
+ const int src2_nb1, const int src2_nb2, const int src3_nb1,
+ const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
+ const int64_t s_off, const int64_t d_inner, const int64_t L_param)
+{
+ const size_t L = L_template == 0 ? L_param : L_template;
+ const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2);
+ const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float));
+ const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float));
+ const float *A_block = (const float *)((const char *)src3 + blockIdx.y * splitD * src3_nb1);
+ const float *B_block = (const float *)((const char *)src4 + (blockIdx.x * src4_nb3));
+ const float *C_block = (const float *)((const char *)src5 + (blockIdx.x * src5_nb3));
+ float *y_block = (float *)((char *)dst + (blockIdx.x * d_inner * L * sizeof(float)) + blockIdx.y * splitD * sizeof(float));
+ float *s_block = (float *)((char *)dst + s_off + blockIdx.x * src0_nb3 + blockIdx.y * splitD * src0_nb2);
+
+ const int stride_x = src1_nb2 / sizeof(float);
+ const int stride_dt = src2_nb1 / sizeof(float);
+ const int stride_B = src4_nb2 / sizeof(float);
+ const int stride_C = src5_nb2 / sizeof(float);
+ const int stride_y = d_inner;
+
+ float regA[N];
+ float regs0[N];
+
+ __shared__ float smemB[N];
+ __shared__ float smemC[N];
+
+#ifdef USE_CUB
+ using BlockLoad = cub::BlockLoad<float, splitD, N, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
+ using BlockStore = cub::BlockStore<float, splitD, N, cub::BLOCK_STORE_WARP_TRANSPOSE>;
+
+ union CubTempStorage {
+ typename BlockLoad::TempStorage load_temp;
+ typename BlockStore::TempStorage store_temp;
+ };
+ __shared__ CubTempStorage cub_temp_storage;
+
+ BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA);
+ BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0);
+#else
+ const int stride_s0 = src0_nb2 / sizeof(float);
+ const int stride_A = src3_nb1 / sizeof(float);
+#pragma unroll
+ for (size_t n = 0; n < N; ++n)
+ {
+ regA[n] = A_block[threadIdx.x * stride_A + n];
+ regs0[n] = s0_block[threadIdx.x * stride_s0 + n];
+ }
+#endif
+
+#pragma unroll
+ for (size_t i = 0; i < L; i++)
+ {
+ if (threadIdx.x < N)
+ {
+ smemB[threadIdx.x] = B_block[i * stride_B + threadIdx.x];
+ smemC[threadIdx.x] = C_block[i * stride_C + threadIdx.x];
+ }
+ __syncthreads();
+
+ float dt_soft_plus = dt_block[i * stride_dt + threadIdx.x];
+ if (dt_soft_plus <= 20.0f)
+ {
+ dt_soft_plus = log1pf(expf(dt_soft_plus));
+ }
+ float x_dt = x_block[i * stride_x + threadIdx.x] * dt_soft_plus;
+
+ float sumf = 0.0f;
+#pragma unroll
+ for (size_t n = 0; n < N; n++)
+ {
+ float state = regs0[n] * expf(dt_soft_plus * regA[n]) + smemB[n] * x_dt;
+ sumf += state * smemC[n];
+ regs0[n] = state;
+ }
+ y_block[i * stride_y + threadIdx.x] = sumf;
+ }
+
+#ifdef USE_CUB
+ BlockStore(cub_temp_storage.store_temp).Store(s_block, regs0);
+#else
+ const int stride_s = stride_s0;
+#pragma unroll
+ for (size_t n = 0; n < N; ++n)
+ {
+ s_block[threadIdx.x * stride_s + n] = regs0[n];
+ }
+#endif
+}
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
+
+// assumes as many threads as d_state
+template <int c_factor, int d_state>
+__global__ void __launch_bounds__(d_state, 1)
+ ssm_scan_f32_group(
+ const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
+ const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
+ const int32_t * __restrict__ src6, float * __restrict__ dst,
+ const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
+ const int src2_nb1, const int src2_nb2, const int src3_nb1,
+ const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
+ const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
+
+ const int warp = threadIdx.x / WARP_SIZE;
+ const int lane = threadIdx.x % WARP_SIZE;
+ const int warp_idx = blockIdx.x * c_factor + warp;
+
+ const int head_idx = warp_idx / d_head;
+ const int head_off = (warp_idx % d_head) * sizeof(float);
+ const int seq_idx = blockIdx.y;
+
+ const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
+
+ // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase
+ const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+ const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float)));
+ const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
+ const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1);
+ const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
+ const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
+ float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx;
+ float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+
+ // strides across n_seq_tokens
+ const int stride_x = src1_nb2 / sizeof(float);
+ const int stride_dt = src2_nb1 / sizeof(float);
+ const int stride_B = src4_nb2 / sizeof(float);
+ const int stride_C = src5_nb2 / sizeof(float);
+ const int stride_y = n_head * d_head;
+
+ float state[c_factor];
+ float state_sum = 0.0f;
+
+#pragma unroll
+ for (int j = 0; j < c_factor; j++) {
+ state[j] = s0_warp[WARP_SIZE * j + lane];
+ }
+
+ for (int64_t i = 0; i < n_tok; i++) {
+ // NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here.
+ // Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead.
+ const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]);
+
+ state_sum = 0.0f;
+ const float dA = expf(dt_soft_plus * A_warp[0]);
+ const float x_dt = x_warp[i * stride_x] * dt_soft_plus;
+#pragma unroll
+ for (int j = 0; j < c_factor; j++) {
+ const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane];
+ const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane];
+ state[j] = (state[j] * dA) + (B_val * x_dt);
+ state_sum += state[j] * C_val;
+ }
+
+ // parallel accumulation for output
+ state_sum = warp_reduce_sum(state_sum);
+
+ if (lane == 0) {
+ y_warp[i * stride_y] = state_sum;
+ }
+ }
+
+ // write back the state
+#pragma unroll
+ for (int j = 0; j < c_factor; j++) {
+ s_warp[WARP_SIZE * j + lane] = state[j];
+ }
+}
+
+static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
+ const float * src4, const float * src5, const int32_t * src6, float * dst,
+ const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
+ const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
+ const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
+ const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
+ cudaStream_t stream) {
+ // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
+ if (src3_nb1 == sizeof(float)) {
+ // Mamba-2
+ if (d_state == 128) {
+ constexpr int threads = 128;
+ constexpr int num_warps = threads/WARP_SIZE;
+
+ const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
+ ssm_scan_f32_group<128/WARP_SIZE, 128><<<blocks, threads, 0, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
+ src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
+ } else if (d_state == 256) { // Falcon-H1
+ constexpr int threads = 256;
+ constexpr int num_warps = threads/WARP_SIZE;
+
+ const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1);
+ ssm_scan_f32_group<256/WARP_SIZE, 256><<<blocks, threads, 0, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
+ src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
+ } else {
+ GGML_ABORT("doesn't support d_state!=(128 or 256).");
+ }
+ } else {
+ // Mamba-1
+ constexpr int threads = 128;
+ GGML_ASSERT(n_head % threads == 0);
+ GGML_ASSERT(head_dim == 1);
+ GGML_ASSERT(n_group == 1);
+ const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
+ const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
+ if (d_state == 16) {
+ switch (n_tok)
+ {
+ case 1:
+ ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ case 2:
+ ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ case 3:
+ ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ case 4:
+ ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ case 5:
+ ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ case 6:
+ ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ case 7:
+ ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ case 8:
+ ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ default:
+ ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>(
+ src0, src1, src2, src3, src4, src5, src6, dst,
+ src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+ src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+ break;
+ }
+ } else {
+ GGML_ABORT("doesn't support d_state!=16.");
+ }
+ }
+}
+
+void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const struct ggml_tensor * src0 = dst->src[0]; // s
+ const struct ggml_tensor * src1 = dst->src[1]; // x
+ const struct ggml_tensor * src2 = dst->src[2]; // dt
+ const struct ggml_tensor * src3 = dst->src[3]; // A
+ const struct ggml_tensor * src4 = dst->src[4]; // B
+ const struct ggml_tensor * src5 = dst->src[5]; // C
+ const struct ggml_tensor * src6 = dst->src[6]; // ids
+
+ const int64_t nc = src0->ne[0]; // d_state
+ const int64_t nr = src0->ne[1]; // head_dim or 1
+ const int64_t nh = src1->ne[1]; // n_head
+ const int64_t ng = src4->ne[1]; // n_group
+ const int64_t n_t = src1->ne[2]; // number of tokens per sequence
+ const int64_t n_s = src1->ne[3]; // number of sequences in the batch
+
+ const int64_t s_off = ggml_nelements(src1) * sizeof(float);
+
+ GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
+ GGML_ASSERT(src2->nb[0] == sizeof(float));
+ GGML_ASSERT(src3->nb[0] == sizeof(float));
+ GGML_ASSERT(src4->nb[0] == sizeof(float));
+ GGML_ASSERT(src5->nb[0] == sizeof(float));
+ GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
+
+ const float * src0_d = (const float *) src0->data;
+ const float * src1_d = (const float *) src1->data;
+ const float * src2_d = (const float *) src2->data;
+ const float * src3_d = (const float *) src3->data;
+ const float * src4_d = (const float *) src4->data;
+ const float * src5_d = (const float *) src5->data;
+ const int32_t * src6_d = (const int32_t *) src6->data;
+ float * dst_d = (float *) dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src6->type == GGML_TYPE_I32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
+ src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
+ src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
+ s_off, nc, nr, nh, ng, n_t, n_s, stream);
+}
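
Both kernels in this file implement the same selective-scan recurrence, read directly from the code above: dt is passed through softplus (values above 20 are left unchanged, since softplus(x) ≈ x there and expf would overflow), each state element decays by exp(dt'·A) and accumulates B·x·dt', and the output is the C-weighted sum of the state. ssm_scan_f32 keeps a per-(channel, state) A (Mamba-1), while ssm_scan_f32_group uses a single scalar A per head (Mamba-2 / Falcon-H1). A scalar reference of the recurrence for one channel of one sequence, with illustrative helper and buffer names:

#include <cmath>
#include <vector>

// s: [N] running state, A: [N], B/C: [L*N] (per-token rows), x/dt/y: [L]
static void ssm_scan_ref(std::vector<float> & s,
                         const std::vector<float> & A, const std::vector<float> & B,
                         const std::vector<float> & C, const std::vector<float> & x,
                         const std::vector<float> & dt, std::vector<float> & y,
                         int N, int L) {
    for (int t = 0; t < L; ++t) {
        const float dtp  = dt[t] <= 20.0f ? std::log1p(std::exp(dt[t])) : dt[t]; // softplus
        const float x_dt = x[t] * dtp;
        float sum = 0.0f;
        for (int n = 0; n < N; ++n) {
            s[n] = s[n] * std::exp(dtp * A[n]) + B[t*N + n] * x_dt; // state update
            sum += s[n] * C[t*N + n];                               // output projection
        }
        y[t] = sum;
    }
}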
diff --git a/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cuh b/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cuh
new file mode 100644
index 0000000..ee078f5
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/ssm-scan.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/sum.cu b/llama.cpp/ggml/src/ggml-cuda/sum.cu
new file mode 100644
index 0000000..c56257b
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/sum.cu
@@ -0,0 +1,41 @@
+#include "sum.cuh"
+#include "sumrows.cuh"
+
+#ifdef GGML_CUDA_USE_CUB
+#include <cub/cub.cuh>
+using namespace cub;
+#endif // GGML_CUDA_USE_CUB
+
+#include <cstdint>
+
+void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream) {
+#ifdef GGML_CUDA_USE_CUB
+ size_t tmp_size = 0;
+ DeviceReduce::Sum(nullptr, tmp_size, x, dst, ne, stream);
+ ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
+ DeviceReduce::Sum(tmp_alloc.ptr, tmp_size, x, dst, ne, stream);
+#else
+ // Use (inefficient) sum_rows implementation as a fallback.
+ // For AMD there is rocPRIM, which could be used as a drop-in replacement via hipcub, but this would require moving from C++11 to C++14.
+ sum_rows_f32_cuda(x, dst, ne, 1, stream);
+ GGML_UNUSED(pool);
+#endif // GGML_CUDA_USE_CUB
+}
+
+void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_contiguously_allocated(src0));
+
+ const float * src0_d = (const float *) src0->data;
+ float * dst_d = (float *) dst->data;
+
+ const int64_t ne = ggml_nelements(src0);
+
+ ggml_cuda_pool & pool = ctx.pool();
+ cudaStream_t stream = ctx.stream();
+
+ sum_f32_cuda(pool, src0_d, dst_d, ne, stream);
+}
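
sum_f32_cuda relies on the standard two-phase CUB pattern: DeviceReduce::Sum is first called with a null temporary buffer to query the scratch size, the buffer is taken from the ggml pool, and the same call is repeated to run the reduction. A standalone sketch of that pattern outside the pool (cudaMalloc/cudaFree and the function name device_sum are illustrative):

#include <cub/cub.cuh>

void device_sum(const float * d_in, float * d_out, int n, cudaStream_t stream) {
    size_t tmp_bytes = 0;
    // First call: no scratch buffer, only reports how many bytes are needed.
    cub::DeviceReduce::Sum(nullptr, tmp_bytes, d_in, d_out, n, stream);
    void * d_tmp = nullptr;
    cudaMalloc(&d_tmp, tmp_bytes);
    // Second call: same arguments plus the scratch buffer, performs the reduction.
    cub::DeviceReduce::Sum(d_tmp, tmp_bytes, d_in, d_out, n, stream);
    cudaFree(d_tmp);
}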
diff --git a/llama.cpp/ggml/src/ggml-cuda/sum.cuh b/llama.cpp/ggml/src/ggml-cuda/sum.cuh
new file mode 100644
index 0000000..8cadc37
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/sum.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+void sum_f32_cuda(ggml_cuda_pool & pool, const float * x, float * dst, const int64_t ne, cudaStream_t stream);
+
+void ggml_cuda_op_sum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/sumrows.cu b/llama.cpp/ggml/src/ggml-cuda/sumrows.cu
new file mode 100644
index 0000000..4025771
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/sumrows.cu
@@ -0,0 +1,43 @@
+#include "reduce_rows.cuh"
+#include "sumrows.cuh"
+
+void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+ const int id = ggml_cuda_get_device();
+ const int nsm = ggml_cuda_info().devices[id].nsm;
+ const dim3 block_nums(nrows, 1, 1);
+ if ((nrows / nsm) < 2) {
+ const dim3 block_dims(512, 1, 1);
+ reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+ } else {
+ const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+ reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+ }
+}
+
+void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ const int64_t ncols = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ const dim3 block_nums(nrows, 1, 1);
+
+ const int id = ggml_cuda_get_device();
+ const int nsm = ggml_cuda_info().devices[id].nsm;
+ if ((nrows / nsm) < 2) {
+ // Increase num threads to 512 for small nrows to better hide the latency
+ const dim3 block_dims(512, 1, 1);
+ reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+ } else {
+ // Enough active SMs to hide latency, use smaller blocks to allow better scheduling
+ const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1);
+ reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols);
+ }
+}
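
The launch heuristic above picks wide 512-thread blocks when there are fewer than two rows per SM (one block per row, so extra threads help hide memory latency) and smaller 32- or 128-thread blocks otherwise, when there are enough blocks to keep the SMs busy. The semantics of reduce_rows_f32</*norm=*/false> are just a per-row sum; a scalar reference (helper name illustrative):

// dst[r] = sum of the ncols values in row r of x (row-major, contiguous).
static void sum_rows_ref(const float * x, float * dst, int ncols, int nrows) {
    for (int r = 0; r < nrows; ++r) {
        float acc = 0.0f;
        for (int c = 0; c < ncols; ++c) {
            acc += x[r*ncols + c];
        }
        dst[r] = acc;
    }
}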
diff --git a/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh b/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh
new file mode 100644
index 0000000..3431c59
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/sumrows.cuh
@@ -0,0 +1,4 @@
+#include "common.cuh"
+
+void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream);
+void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu
new file mode 100644
index 0000000..fb26abe
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu
new file mode 100644
index 0000000..1f554d8
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
new file mode 100644
index 0000000..dc16829
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8);
+DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8);
+DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
+DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
+DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
+DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu
new file mode 100644
index 0000000..9d3cfd8
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1);
+DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1);
+DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1);
+DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1);
+DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1);
+DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu
new file mode 100644
index 0000000..2e1883a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2);
+DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2);
+DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2);
+DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2);
+DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2);
+DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
new file mode 100644
index 0000000..517993c
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
@@ -0,0 +1,11 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 16, 4);
+DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4);
+DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
+DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
+DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
+DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu
new file mode 100644
index 0000000..f011a20
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu
new file mode 100644
index 0000000..264751d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
new file mode 100644
index 0000000..97b19c6
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
@@ -0,0 +1,11 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4);
+DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4);
+DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
+DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
+DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
+DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
new file mode 100644
index 0000000..163b1d9
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8);
+DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8);
+DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
+DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
+DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
+DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu
new file mode 100644
index 0000000..0543532
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1);
+DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1);
+DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1);
+DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1);
+DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1);
+DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu
new file mode 100644
index 0000000..407b6cf
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2);
+DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2);
+DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2);
+DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2);
+DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2);
+DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu
new file mode 100644
index 0000000..f5fd0e2
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu
new file mode 100644
index 0000000..5e46685
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2);
+DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2);
+DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2);
+DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2);
+DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2);
+DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
new file mode 100644
index 0000000..989626d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
@@ -0,0 +1,11 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4);
+DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4);
+DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
+DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
+DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
+DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
new file mode 100644
index 0000000..bad296b
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8);
+DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8);
+DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
+DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
+DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
+DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu
new file mode 100644
index 0000000..0d7a9c7
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1);
+DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1);
+DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1);
+DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1);
+DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1);
+DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu
new file mode 100644
index 0000000..9d5a997
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1);
+DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1);
+DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1);
+DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1);
+DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1);
+DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu
new file mode 100644
index 0000000..a6e6f09
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2);
+DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2);
+DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2);
+DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2);
+DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2);
+DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
new file mode 100644
index 0000000..173de7a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
@@ -0,0 +1,11 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4);
+DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4);
+DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
+DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
+DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
+DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
new file mode 100644
index 0000000..680a13c
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8);
+DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8);
+DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);
+DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
+DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
+DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu
new file mode 100644
index 0000000..a8b15ad
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(112, 112);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu
new file mode 100644
index 0000000..1da1810
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(128, 128);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu
new file mode 100644
index 0000000..bc65c72
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(256, 256);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu
new file mode 100644
index 0000000..10b330f
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(40, 40);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu
new file mode 100644
index 0000000..254b7d2
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(576, 512);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu
new file mode 100644
index 0000000..5caffac
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(64, 64);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu
new file mode 100644
index 0000000..8f9d531
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(72, 72);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu
new file mode 100644
index 0000000..90abb3b
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(80, 80);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu
new file mode 100644
index 0000000..7292c0a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(96, 96);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
new file mode 100644
index 0000000..c357abd
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu
new file mode 100644
index 0000000..4b14865
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu
new file mode 100644
index 0000000..ef77157
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q4_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu
new file mode 100644
index 0000000..9ae11cc
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu
new file mode 100644
index 0000000..10ed48a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q5_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu
new file mode 100644
index 0000000..4fcc3f3
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-q8_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu
new file mode 100644
index 0000000..7ca5053
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-f16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
new file mode 100644
index 0000000..6ef1a48
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu
new file mode 100644
index 0000000..4c0532c
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu
new file mode 100644
index 0000000..ed3d7ba
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu
new file mode 100644
index 0000000..687f254
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q5_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu
new file mode 100644
index 0000000..41107c4
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-q8_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu
new file mode 100644
index 0000000..d523ce0
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-f16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu
new file mode 100644
index 0000000..8b9ed35
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu
new file mode 100644
index 0000000..0553e46
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q4_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu
new file mode 100644
index 0000000..8390eaf
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu
new file mode 100644
index 0000000..f61e19d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q5_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu
new file mode 100644
index 0000000..86a1882
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-q8_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu
new file mode 100644
index 0000000..1d7af47
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-f16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu
new file mode 100644
index 0000000..837224d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu
new file mode 100644
index 0000000..0dd7dd6
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q4_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu
new file mode 100644
index 0000000..41b859f
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu
new file mode 100644
index 0000000..d2e5ffd
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q5_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu
new file mode 100644
index 0000000..81ff740
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-q8_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu
new file mode 100644
index 0000000..a38dae1
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-f16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu
new file mode 100644
index 0000000..2304571
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu
new file mode 100644
index 0000000..84b83e5
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q4_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu
new file mode 100644
index 0000000..39f80e2
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu
new file mode 100644
index 0000000..cf4e661
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q5_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu
new file mode 100644
index 0000000..6565418
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-q8_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu
new file mode 100644
index 0000000..a1bc3f5
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-f16.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu
new file mode 100644
index 0000000..4b76a9b
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu
new file mode 100644
index 0000000..77d0412
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q4_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu
new file mode 100644
index 0000000..6e170fe
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu
new file mode 100644
index 0000000..b617cd7
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q5_1.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
new file mode 100644
index 0000000..a5b768b
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/llama.cpp/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
new file mode 100755
index 0000000..e382df1
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/generate_cu_files.py
@@ -0,0 +1,99 @@
+#!/usr/bin/env python3
+
+from glob import glob
+import os
+
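+# Regenerates one .cu file per template instantiation (FlashAttention tile/vec/MMA, MMQ, MMF)
+# so that each instantiation lives in its own translation unit; any previously generated
+# files in this directory are removed first.
+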
+HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
+
+TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
+
+SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE({head_size_kq}, {head_size_v});
+"""
+
+SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.cuh"
+
+DECL_FATTN_VEC_CASE( 64, {type_k}, {type_v});
+DECL_FATTN_VEC_CASE(128, {type_k}, {type_v});
+DECL_FATTN_VEC_CASE(256, {type_k}, {type_v});
+"""
+
+SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+"""
+
+SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"
+
+TYPES_MMQ = [
+ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
+ "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
+ "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
+ "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4"
+]
+
+SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE({type});
+"""
+
+SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE({type});
+"""
+
+
+def get_short_name(long_quant_name):
+ return long_quant_name.replace("GGML_TYPE_", "").lower()
+
+
+for filename in glob("*.cu"):
+ os.remove(filename)
+
+for head_size_kq in HEAD_SIZES_KQ:
+ head_size_v = head_size_kq if head_size_kq != 576 else 512
+ with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f:
+ f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))
+
+for type_k in TYPES_KV:
+ for type_v in TYPES_KV:
+ with open(f"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
+ f.write(SOURCE_FATTN_VEC.format(type_k=type_k, type_v=type_v))
+
+for ncols in [8, 16, 32, 64]:
+ for ncols2 in [1, 2, 4, 8, 16, 32]:
+ if ncols2 > ncols:
+ continue
+ ncols1 = ncols // ncols2
+ with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
+ f.write(SOURCE_FATTN_MMA_START)
+
+ for head_size_kq in HEAD_SIZES_KQ:
+ if head_size_kq == 40:
+ continue
+ if head_size_kq == 72:
+ continue
+ if head_size_kq != 576 and ncols2 in (16, 32):
+ continue
+ if head_size_kq == 576 and ncols2 not in (4, 16, 32):
+ continue
+ head_size_v = head_size_kq if head_size_kq != 576 else 512
+ f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
+
+for type in TYPES_MMQ:
+ with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
+ f.write(SOURCE_MMQ.format(type=type))
+
+for ncols in range(1, 17):
+    with open(f"mmf-instance-ncols_{ncols}.cu", "w") as f:
+        f.write(SOURCE_MMF.format(type=ncols))
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu
new file mode 100644
index 0000000..f594d5d
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu
new file mode 100644
index 0000000..9cc6772
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(10);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu
new file mode 100644
index 0000000..317f487
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(11);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu
new file mode 100644
index 0000000..dc00332
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(12);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu
new file mode 100644
index 0000000..0782101
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(13);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu
new file mode 100644
index 0000000..a23ad6a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(14);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu
new file mode 100644
index 0000000..0fe3f78
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(15);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu
new file mode 100644
index 0000000..5440863
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(16);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu
new file mode 100644
index 0000000..3b90179
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(2);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu
new file mode 100644
index 0000000..56e940b
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(3);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu
new file mode 100644
index 0000000..a7665d4
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(4);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu
new file mode 100644
index 0000000..3a1dff2
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(5);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu
new file mode 100644
index 0000000..400fb7c
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(6);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu
new file mode 100644
index 0000000..954a1c7
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(7);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu
new file mode 100644
index 0000000..f1bd09c
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(8);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu
new file mode 100644
index 0000000..1255ac2
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(9);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu
new file mode 100644
index 0000000..84ec850
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu
new file mode 100644
index 0000000..583c4e5
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu
new file mode 100644
index 0000000..edaf156
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu
new file mode 100644
index 0000000..233d934
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu
new file mode 100644
index 0000000..6092dc7
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu
new file mode 100644
index 0000000..1d5bd20
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu
new file mode 100644
index 0000000..eb02fab
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu
new file mode 100644
index 0000000..1eb3b74
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu
new file mode 100644
index 0000000..c14624c
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_MXFP4);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu
new file mode 100644
index 0000000..6415369
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q2_K);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu
new file mode 100644
index 0000000..ffb6213
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q3_K);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu
new file mode 100644
index 0000000..0c0b0c8
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q4_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu
new file mode 100644
index 0000000..ee67f69
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q4_1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu
new file mode 100644
index 0000000..9eeb3cd
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q4_K);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu
new file mode 100644
index 0000000..cc57fb9
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q5_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu
new file mode 100644
index 0000000..721ac79
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q5_1);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu
new file mode 100644
index 0000000..a2e90ff
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q5_K);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu
new file mode 100644
index 0000000..470938f
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q6_K);
diff --git a/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu
new file mode 100644
index 0000000..974477b
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_Q8_0);
diff --git a/llama.cpp/ggml/src/ggml-cuda/top-k.cu b/llama.cpp/ggml/src/ggml-cuda/top-k.cu
new file mode 100644
index 0000000..785a183
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/top-k.cu
@@ -0,0 +1,95 @@
+#include "argsort.cuh"
+#include "top-k.cuh"
+
+#ifdef GGML_CUDA_USE_CUB
+# include <cub/cub.cuh>
+# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
+# define CUB_TOP_K_AVAILABLE
+using namespace cub;
+# endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2
+#endif // GGML_CUDA_USE_CUB
+
+#ifdef CUB_TOP_K_AVAILABLE
+
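+// Selects the k largest values of a single row with cub::DeviceTopK::MaxPairs and
+// writes only their indices to dst (the values themselves go to a discard iterator).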
+static void top_k_cub(ggml_cuda_pool & pool,
+ const float * src,
+ int * dst,
+ const int ncols,
+ const int k,
+ cudaStream_t stream) {
+ auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
+ cuda::execution::output_ordering::unsorted);
+ auto stream_env = cuda::stream_ref{ stream };
+ auto env = cuda::std::execution::env{ stream_env, requirements };
+
+ auto indexes_in = cuda::make_counting_iterator(0);
+
+ size_t temp_storage_bytes = 0;
+ DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k,
+ env);
+
+ ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
+ void * d_temp_storage = temp_storage_alloc.get();
+
+ DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst,
+ ncols, k, env);
+}
+
+#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
+
+static int next_power_of_2(int x) {
+ int n = 1;
+ while (n < x) {
+ n *= 2;
+ }
+ return n;
+}
+
+#endif // CUB_TOP_K_AVAILABLE
+
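+// Per-row top-k dispatch: uses the CUB DeviceTopK path when available, otherwise falls
+// back to a full argsort of each row followed by a strided copy of the first k indices.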
+void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *) src0->data;
+ int * dst_d = (int *) dst->data;
+ cudaStream_t stream = ctx.stream();
+
+    // TODO: check whether these asserts are strictly necessary
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_I32);
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ const int64_t ncols = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+ const int64_t k = dst->ne[0];
+ ggml_cuda_pool & pool = ctx.pool();
+#ifdef CUB_TOP_K_AVAILABLE
+ // TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented
+ // https://github.com/NVIDIA/cccl/issues/6391
+ // TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k
+ for (int i = 0; i < nrows; i++) {
+ top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream);
+ }
+#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
+ // Fall back to argsort + copy
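+    // sort every row in descending order, then copy the first k indices of each row
+    // into the densely packed destination with a strided 2D memcpy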
+ const int ncols_pad = next_power_of_2(ncols);
+ const size_t shared_mem = ncols_pad * sizeof(int);
+ const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+ ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
+ int * tmp_dst = temp_dst_alloc.get();
+
+ if (shared_mem > max_shared_mem || ncols > 1024) {
+ argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+ } else {
+ argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+ }
+ CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
+ cudaMemcpyDeviceToDevice, stream));
+#else // GGML_CUDA_USE_CUB
+ ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
+ int * tmp_dst = temp_dst_alloc.get();
+ argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
+ CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
+ cudaMemcpyDeviceToDevice, stream));
+#endif
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/top-k.cuh b/llama.cpp/ggml/src/ggml-cuda/top-k.cuh
new file mode 100644
index 0000000..f4d8f61
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/top-k.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu b/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu
new file mode 100644
index 0000000..08a8899
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/topk-moe.cu
@@ -0,0 +1,403 @@
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
+#include "topk-moe.cuh"
+
+#include <cmath>
+#include <initializer_list>
+
+// Kernel config struct - passed by value to CUDA kernel
+struct topk_moe_config {
+ bool use_sigmoid;
+ bool with_norm;
+ bool delayed_softmax;
+};
+
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+template <int experts_per_thread, bool use_limit>
+__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
+ float max_val = -INFINITY;
+
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ const int idx = lane + i * WARP_SIZE;
+ const bool active = !use_limit || (idx < limit);
+ if (active) {
+ max_val = max(max_val, vals[i]);
+ }
+ }
+
+ max_val = warp_reduce_max(max_val);
+
+ float sum = 0.f;
+
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ const int idx = lane + i * WARP_SIZE;
+ const bool active = !use_limit || (idx < limit);
+ if (active) {
+ const float val = expf(vals[i] - max_val);
+ vals[i] = val;
+ sum += val;
+ } else {
+ vals[i] = 0.f;
+ }
+ }
+
+ sum = warp_reduce_sum(sum);
+
+ const float inv_sum = 1.0f / sum;
+
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ const int idx = lane + i * WARP_SIZE;
+ const bool active = !use_limit || (idx < limit);
+ if (active) {
+ vals[i] *= inv_sum;
+ }
+ }
+}
+
+template <int experts_per_thread, bool use_limit>
+__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ const int idx = lane + i * WARP_SIZE;
+ const bool active = !use_limit || (idx < limit);
+ vals[i] = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY;
+ }
+}
+
+/*
+    This kernel does the following:
+    1. optionally applies softmax (or sigmoid) over the logits per token [n_experts, n_tokens]
+    2. argmax-reduces over the top-k (n_expert_used) logits
+    3. writes the selected weights + ids to global memory
+    4. optionally normalizes the weights or applies a delayed softmax over the selected logits
+
+    It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models.
+*/
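+// Each row (token) is handled by a single warp; each thread keeps up to
+// max(n_experts / WARP_SIZE, 1) logits in registers.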
+template <int n_experts, bool has_bias>
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
+ float * weights,
+ int32_t * ids,
+ float * bias,
+ const int n_rows,
+ const int n_expert_used,
+ const float clamp_val,
+ const float scale_val,
+ const topk_moe_config config) {
+ const int row = blockIdx.x * blockDim.y + threadIdx.y;
+ if (row >= n_rows) {
+ return;
+ }
+
+ logits += n_experts * row;
+ weights += n_expert_used * row;
+ ids += n_experts * row;
+
+ constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
+
+ float wt[experts_per_thread];
+
+ // Initialize all slots to -INFINITY
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ wt[i] = -INFINITY;
+ }
+
+#pragma unroll
+ for (int i = 0; i < n_experts; i += WARP_SIZE) {
+ const int expert = i + threadIdx.x;
+ wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
+ }
+
+ if (!config.delayed_softmax) {
+ if (config.use_sigmoid) {
+ sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
+ } else {
+ softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
+ }
+ }
+
+ // selection_wt is only needed when bias is present (selection uses wt + bias)
+ // when no bias, we use wt directly for both selection and weight values
+ float selection_wt[has_bias ? experts_per_thread : 1];
+
+ if constexpr (has_bias) {
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ selection_wt[i] = -INFINITY;
+ }
+#pragma unroll
+ for (int i = 0; i < n_experts; i += WARP_SIZE) {
+ const int expert = i + threadIdx.x;
+ selection_wt[i / WARP_SIZE] =
+ (n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY;
+ }
+ }
+
+    // at this point, each thread holds either a portion of the softmax distribution
+    // or the raw logits. We do the argmax reduce over n_expert_used, each time marking
+    // the selected expert's weight as -inf to exclude it from the next iteration.
+
+ float wt_sum = 0.f;
+
+ float output_weights[experts_per_thread];
+
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ output_weights[i] = 0.f;
+ }
+
+ for (int k = 0; k < n_expert_used; k++) {
+ float max_val = wt[0];
+ int max_expert = threadIdx.x;
+
+ if constexpr (has_bias) {
+ float max_val_s = selection_wt[0];
+
+#pragma unroll
+ for (int i = 1; i < experts_per_thread; i++) {
+ const int expert = threadIdx.x + i * WARP_SIZE;
+ if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) {
+ max_val = wt[i];
+ max_val_s = selection_wt[i];
+ max_expert = expert;
+ }
+ }
+
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+ const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+ const float val_s = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE);
+ const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+ if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
+ max_val = val;
+ max_val_s = val_s;
+ max_expert = expert;
+ }
+ }
+
+ if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+ selection_wt[max_expert / WARP_SIZE] = -INFINITY;
+ }
+ } else {
+#pragma unroll
+ for (int i = 1; i < experts_per_thread; i++) {
+ const int expert = threadIdx.x + i * WARP_SIZE;
+ if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
+ max_val = wt[i];
+ max_expert = expert;
+ }
+ }
+
+#pragma unroll
+ for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+ const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+ const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
+ if (val > max_val || (val == max_val && expert < max_expert)) {
+ max_val = val;
+ max_expert = expert;
+ }
+ }
+
+ if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+ wt[max_expert / WARP_SIZE] = -INFINITY;
+ }
+ }
+
+ if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
+ output_weights[k / WARP_SIZE] = max_val;
+ }
+
+ if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+ ids[k] = max_expert;
+ if (config.with_norm) {
+ wt_sum += max_val;
+ }
+ }
+ }
+
+ if (config.with_norm) {
+ wt_sum = warp_reduce_sum(wt_sum);
+ wt_sum = max(wt_sum, clamp_val);
+ const float inv_sum = 1.0f / wt_sum;
+
+ for (int i = 0; i < experts_per_thread; i++) {
+ output_weights[i] *= inv_sum;
+ }
+ }
+
+ if (config.delayed_softmax) {
+ softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
+ }
+
+#pragma unroll
+ for (int i = 0; i < experts_per_thread; i++) {
+ const int idx = i * WARP_SIZE + threadIdx.x;
+ if (idx < n_expert_used) {
+ weights[idx] = output_weights[i] * scale_val;
+ }
+ }
+}
+
+template<bool has_bias>
+static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
+ const float * logits,
+ float * weights,
+ int32_t * ids,
+ float * bias,
+ const int n_rows,
+ const int n_expert,
+ const int n_expert_used,
+ const float clamp_val,
+ const float scale_val,
+ const topk_moe_config config) {
+ GGML_ASSERT(!(config.with_norm && config.delayed_softmax) &&
+ "delayed softmax is not supported with weight normalization");
+ const int rows_per_block = 4;
+ dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
+ dim3 block_dims(WARP_SIZE, rows_per_block, 1);
+ cudaStream_t stream = ctx.stream();
+
+ switch (n_expert) {
+ case 1:
+ topk_moe_cuda<1, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 2:
+ topk_moe_cuda<2, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 4:
+ topk_moe_cuda<4, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 8:
+ topk_moe_cuda<8, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 16:
+ topk_moe_cuda<16, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 32:
+ topk_moe_cuda<32, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 64:
+ topk_moe_cuda<64, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 128:
+ topk_moe_cuda<128, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 256:
+ topk_moe_cuda<256, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 512:
+ topk_moe_cuda<512, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ case 576:
+ topk_moe_cuda<576, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
+ clamp_val, scale_val, config);
+ break;
+ default:
+ GGML_ASSERT(false && "fatal error");
+ break;
+ }
+}
+
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
+ const ggml_tensor * logits,
+ ggml_tensor * weights,
+ ggml_tensor * ids,
+ const ggml_tensor * clamp,
+ const ggml_tensor * scale,
+ const ggml_tensor * bias,
+ const ggml_cuda_topk_moe_args & args) {
+ GGML_ASSERT(logits->type == GGML_TYPE_F32);
+ GGML_ASSERT(weights->type == GGML_TYPE_F32);
+ GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+ const int n_experts = logits->ne[0];
+ const int n_rows = logits->ne[1];
+
+ const float * logits_d = (const float *) logits->data;
+ float * weights_d = (float *) weights->data;
+ int32_t * ids_d = (int32_t *) ids->data;
+ float * bias_d = bias ? (float *) bias->data : nullptr;
+
+ float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f;
+
+ GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
+
+ const int n_expert_used = weights->ne[1];
+
+ const bool with_norm = clamp != nullptr;
+
+ float clamp_val = -INFINITY;
+ if (clamp) {
+ clamp_val = ggml_get_op_params_f32(clamp, 0);
+ }
+
+ topk_moe_config config;
+ config.use_sigmoid = args.sigmoid;
+ config.with_norm = with_norm;
+ config.delayed_softmax = args.delayed_softmax;
+
+ if (bias) {
+ launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
+ scale_val, config);
+ } else {
+ launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
+ scale_val, config);
+ }
+}
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
+ const ggml_tensor * weights,
+ const ggml_tensor * logits,
+ const ggml_tensor * ids) {
+ const int n_expert = ids->nb[1] / ids->nb[0];
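+    // the fused kernel is only instantiated for power-of-two expert counts up to 512, plus 576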
+ if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) {
+ return false;
+ }
+
+ if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) {
+ return false;
+ }
+
+ if (gating_op->op == GGML_OP_SOFT_MAX) {
+ const ggml_tensor * softmax = gating_op;
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+
+ memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
+
+ if (!ggml_is_contiguous(softmax->src[0])) {
+ return false;
+ }
+
+ if (scale != 1.0f || max_bias != 0.0f) {
+ return false;
+ }
+
+ // don't fuse when masks or sinks are present
+ if (softmax->src[1] || softmax->src[2]) {
+ return false;
+ }
+ } else if (gating_op->op == GGML_OP_UNARY) {
+ ggml_unary_op op = ggml_get_unary_op(gating_op);
+
+ if (op != GGML_UNARY_OP_SIGMOID) {
+ return false;
+ }
+ }
+
+ return true;
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh b/llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh
new file mode 100644
index 0000000..243dc2f
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/topk-moe.cuh
@@ -0,0 +1,27 @@
+#include "common.cuh"
+#include "ggml.h"
+
+#include <initializer_list>
+
+struct ggml_cuda_topk_moe_args {
+ bool sigmoid{};
+ bool softmax{};
+ bool delayed_softmax{};
+ bool prob_bias{};
+ bool norm{};
+ bool scale{};
+};
+
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
+ const ggml_tensor * logits,
+ ggml_tensor * weights,
+ ggml_tensor * ids,
+ const ggml_tensor * clamp,
+ const ggml_tensor * scale,
+ const ggml_tensor * bias,
+ const ggml_cuda_topk_moe_args & args);
+
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
+ const ggml_tensor * weights,
+ const ggml_tensor * logits,
+ const ggml_tensor * ids);
diff --git a/llama.cpp/ggml/src/ggml-cuda/tri.cu b/llama.cpp/ggml/src/ggml-cuda/tri.cu
new file mode 100644
index 0000000..44156b6
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/tri.cu
@@ -0,0 +1,136 @@
+#include "common.cuh"
+#include "convert.cuh"
+#include "tri.cuh"
+#include "ggml.h"
+
+template<typename T, bool prefix_keep, int add_to_split>
+static __global__ void tri_kernel(
+ const T * src, T * dst,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+ const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
+ const int64_t i3 = blockIdx.z;
+ const int64_t i2 = blockIdx.y;
+ const int64_t i1 = blockIdx.x;
+ const int64_t split_point = i1 + add_to_split;
+
+ GGML_UNUSED_VARS(nb00, nb0);
+
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+ return;
+ }
+
+ const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03;
+ T * dst_row = dst + i1*nb1 + i2*nb2 + i3*nb3;
+
+ if constexpr (prefix_keep) {
+ for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
+ dst_row[i0] = src_row[i0];
+ }
+ for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
+ dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);
+ }
+ } else {
+ for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
+ dst_row[i0] = ggml_cuda_cast<T, float>(0.0f);
+ }
+ for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
+ dst_row[i0] = src_row[i0];
+ }
+ }
+}
+
+template<typename T>
+static void tri_cuda(
+ const T * src, T * dst,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+ const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+ const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+ const ggml_tri_type ttype,
+ cudaStream_t stream) {
+
+ dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
+ dim3 grid_dims(ne01, ne02, ne03);
+ const size_t type_size = sizeof(T);
+
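+    // prefix_keep: keep the elements before the split point and zero the rest (lower-triangular
+    // variants); otherwise zero the prefix and keep the suffix (upper-triangular variants).
+    // add_to_split shifts the split point by one column to include/exclude the diagonal.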
+ const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0;
+ const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);
+
+ if (prefix_keep) {
+ if (add_to_split == 0) {
+ tri_kernel<T, true, 0><<<grid_dims, block_dims, 0, stream>>>(
+ src, dst,
+ ne00, ne01, ne02, ne03,
+ nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+ nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+ );
+ } else { // only 0 and 1 supported
+ tri_kernel<T, true, 1><<<grid_dims, block_dims, 0, stream>>>(
+ src, dst,
+ ne00, ne01, ne02, ne03,
+ nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+ nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+ );
+ }
+ } else {
+ if (add_to_split == 0) {
+ tri_kernel<T, false, 0><<<grid_dims, block_dims, 0, stream>>>(
+ src, dst,
+ ne00, ne01, ne02, ne03,
+ nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+ nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+ );
+ } else {
+ tri_kernel<T, false, 1><<<grid_dims, block_dims, 0, stream>>>(
+ src, dst,
+ ne00, ne01, ne02, ne03,
+ nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+ nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+ );
+ }
+ }
+}
+
+void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ cudaStream_t stream = ctx.stream();
+
+ const ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0));
+
+ GGML_ASSERT(src0->type == dst->type);
+
+ switch(src0->type) {
+ case GGML_TYPE_F32:
+ {
+ tri_cuda(
+ (const float *)src0->data, (float *)dst->data,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+ ttype, stream
+ );
+ } break;
+ case GGML_TYPE_F16:
+ {
+ tri_cuda(
+ (const half *)src0->data, (half *)dst->data,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+ ttype, stream
+ );
+ } break;
+ case GGML_TYPE_BF16:
+ {
+ tri_cuda(
+ (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
+ src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+ src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+ ttype, stream
+ );
+ } break;
+ default:
+ GGML_ABORT("fatal error");
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/tri.cuh b/llama.cpp/ggml/src/ggml-cuda/tri.cuh
new file mode 100644
index 0000000..a4cc667
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/tri.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_TRI_BLOCK_SIZE 256
+
+void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/tsembd.cu b/llama.cpp/ggml/src/ggml-cuda/tsembd.cu
new file mode 100644
index 0000000..b91a26f
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/tsembd.cu
@@ -0,0 +1,47 @@
+#include "tsembd.cuh"
+
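+// Sinusoidal timestep embedding: for each timestep t and j in [0, dim/2) this writes
+//   dst[j]        = cos(t * exp(-ln(max_period) * j / half))
+//   dst[j + half] = sin(t * exp(-ln(max_period) * j / half))
+// and zero-pads the last element when dim is odd.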
+static __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) {
+    // blockIdx.y: idx of timesteps->ne[0]
+    // blockIdx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE
+ int i = blockIdx.y;
+ int j = threadIdx.x + blockIdx.x * blockDim.x;
+ float * embed_data = (float *)((char *)dst + i*nb1);
+
+ int half = dim / 2;
+ if (dim % 2 != 0 && j == half) {
+ embed_data[2 * half] = 0.f;
+ }
+
+ if (j >= half) {
+ return;
+ }
+
+ float timestep = timesteps[i];
+ float freq = (float)expf(-logf(max_period) * j / half);
+ float arg = timestep * freq;
+ embed_data[j] = cosf(arg);
+ embed_data[j + half] = sinf(arg);
+}
+
+static void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1,
+ const int dim, const int max_period, cudaStream_t stream) {
+ int half_ceil = (dim + 1) / 2;
+ int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE;
+ dim3 gridDim(num_blocks, ne00, 1);
+ timestep_embedding_f32<<<gridDim, CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE, 0, stream>>>(x, dst, nb1, dim, max_period);
+}
+
+void ggml_cuda_op_timestep_embedding(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ const int dim = dst->op_params[0];
+ const int max_period = dst->op_params[1];
+
+ timestep_embedding_f32_cuda(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream);
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/tsembd.cuh b/llama.cpp/ggml/src/ggml-cuda/tsembd.cuh
new file mode 100644
index 0000000..84340e3
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/tsembd.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
+
+void ggml_cuda_op_timestep_embedding(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/unary.cu b/llama.cpp/ggml/src/ggml-cuda/unary.cu
new file mode 100644
index 0000000..d486606
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/unary.cu
@@ -0,0 +1,562 @@
+#include "unary.cuh"
+#include "convert.cuh"
+
+static __device__ __forceinline__ float op_abs(float x) {
+ return fabsf(x);
+}
+
+static __device__ __forceinline__ float op_sgn(float x) {
+ return (x > 0.f ? 1.f : ((x < 0.f ? -1.f : 0.f)));
+}
+
+static __device__ __forceinline__ float op_neg(float x) {
+ return -x;
+}
+
+static __device__ __forceinline__ float op_step(float x) {
+ return x > 0.0f;
+}
+
+static __device__ __forceinline__ float op_gelu(float x) {
+ return ggml_cuda_op_gelu_single(x);
+}
+
+static __device__ __forceinline__ float op_gelu_erf(float x) {
+ const float SQRT_2_INV = 0.70710678118654752440084436210484f;
+
+ return 0.5f*x*(1.0f + erff(x*SQRT_2_INV));
+}
+
+static __device__ __forceinline__ float op_gelu_quick(float x) {
+ const float GELU_QUICK_COEF = -1.702f;
+
+ return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x)));
+}
+
+static __device__ __forceinline__ float op_silu(float x) {
+ return ggml_cuda_op_silu_single(x);
+}
+
+static __device__ __forceinline__ float op_tanh(float x) {
+ return tanhf(x);
+}
+
+static __device__ __forceinline__ float op_relu(float x) {
+ return fmaxf(x, 0);
+}
+
+static __device__ __forceinline__ float op_sigmoid(float x) {
+ return 1.0f / (1.0f + expf(-x));
+}
+
+static __device__ __forceinline__ float op_hardsigmoid(float x) {
+ return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
+}
+
+static __device__ __forceinline__ float op_hardswish(float x) {
+ return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
+}
+
+static __device__ __forceinline__ float op_exp(float x) {
+ return expf(x);
+}
+
+static __device__ __forceinline__ float op_sqr(float x) {
+ return x * x;
+}
+
+static __device__ __forceinline__ float op_sqrt(float x) {
+ return sqrtf(x);
+}
+
+static __device__ __forceinline__ float op_sin(float x) {
+ return sinf(x);
+}
+
+static __device__ __forceinline__ float op_cos(float x) {
+ return cosf(x);
+}
+
+static __device__ __forceinline__ float op_log(float x) {
+ return logf(x);
+}
+
+static __device__ __forceinline__ float op_expm1(float x) {
+ return expm1f(x);
+}
+
+static __device__ __forceinline__ float op_softplus(float x) {
+ return (x > 20.0f) ? x : logf(1.0f + expf(x));
+}
+
+static __device__ __forceinline__ float op_elu(float x) {
+ return (x > 0.f) ? x : expm1f(x);
+}
+
+static __device__ __forceinline__ float op_floor(float x) {
+ return floorf(x);
+}
+
+static __device__ __forceinline__ float op_ceil(float x) {
+ return ceilf(x);
+}
+
+static __device__ __forceinline__ float op_round(float x) {
+    return roundf(x);
+}
+
+static __device__ __forceinline__ float op_trunc(float x) {
+    return truncf(x);
+}
+
+template <float (*op)(float), typename T>
+static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ dst[i] = (T)op((float)x[i]);
+}
+
+template <float (*op)(float), typename T>
+static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
+ unary_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+template <float (*op)(float)>
+void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const void * src0_d = src0->data;
+ void * dst_d = dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+ GGML_ASSERT(src0->type == dst->type);
+
+ if (src0->type == GGML_TYPE_F16) {
+ unary_cuda<op>((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), stream);
+ } else {
+ unary_cuda<op>((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), stream);
+ }
+}
+
+void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_abs>(ctx, dst);
+}
+
+void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_sgn>(ctx, dst);
+}
+
+void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_neg>(ctx, dst);
+}
+
+void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_step>(ctx, dst);
+}
+
+void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_gelu>(ctx, dst);
+}
+
+void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_gelu_erf>(ctx, dst);
+}
+
+void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_gelu_quick>(ctx, dst);
+}
+
+void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_silu>(ctx, dst);
+}
+
+void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_tanh>(ctx, dst);
+}
+
+void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_relu>(ctx, dst);
+}
+
+void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_sigmoid>(ctx, dst);
+}
+
+void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_hardsigmoid>(ctx, dst);
+}
+
+void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_hardswish>(ctx, dst);
+}
+
+void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_exp>(ctx, dst);
+}
+
+void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_sqr>(ctx, dst);
+}
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_sqrt>(ctx, dst);
+}
+
+void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_sin>(ctx, dst);
+}
+
+void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_cos>(ctx, dst);
+}
+
+void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_log>(ctx, dst);
+}
+
+void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_elu>(ctx, dst);
+}
+
+void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_floor>(ctx, dst);
+}
+
+void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_ceil>(ctx, dst);
+}
+
+void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_round>(ctx, dst);
+}
+
+void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_trunc>(ctx, dst);
+}
+
+void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_expm1>(ctx, dst);
+}
+
+void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary<op_softplus>(ctx, dst);
+}
+/* gated ops */
+
+template <float (*op)(float), typename T>
+static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) {
+ const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ // perform base op and multiply with gate (either offset in same tensor or a separate one)
+ const int64_t j0 = (i / n) * o0 + (i % n);
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+
+ dst[i] = (T)(op((float)x[j0]) * (float)g[j1]);
+}
+
+template <float (*op)(float), typename T>
+static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) {
+ const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
+ unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1);
+}
+
+template <float (*op)(float)>
+void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ void * src0_d = src0->data;
+ void * src1_d = src1 ? src1->data : src0->data;
+ const int64_t src0_o = src0->nb[1];
+ const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+ void * dst_d = dst->data;
+ const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+ GGML_ASSERT(src0->type == dst->type);
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+ GGML_ASSERT(src1->ne[0] == nc);
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
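+    // when src1 is absent, src0 packs [value | gate] halves of width nc along dim 0;
+    // op_params[1] (swapped) selects which half feeds the base op and which one gates it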
+ const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+
+ if (src0->type == GGML_TYPE_F16) {
+ half * src0_p = (half *) src0_d;
+ half * src1_p = (half *) src1_d;
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ unary_gated_cuda<op>(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), stream);
+ } else {
+ float * src0_p = (float *) src0_d;
+ float * src1_p = (float *) src1_d;
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ unary_gated_cuda<op>(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), stream);
+ }
+}
+
+void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary_gated<op_relu>(ctx, dst);
+}
+
+void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary_gated<op_gelu>(ctx, dst);
+}
+
+void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
+}
+
+void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary_gated<op_gelu_erf>(ctx, dst);
+}
+
+void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
+}
+
+// swiglu_oai
+
+template <typename T>
+static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
+ const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ // perform base op and multiply with gate (either offset in same tensor or a separate one)
+ const int64_t j0 = (i / n) * o0 + (i % n);
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+
+ float xi = x[j0];
+ float gi = g[j1];
+
+ dst[i] = ggml_cuda_op_swiglu_oai_single(xi, gi, alpha, limit);
+}
+
+template <typename T>
+static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
+ const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
+ swiglu_oai_kernel<<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
+}
+
+void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ void * src0_d = src0->data;
+ void * src1_d = src1 ? src1->data : src0->data;
+ const int64_t src0_o = src0->nb[1];
+ const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+ void * dst_d = dst->data;
+ const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == dst->type);
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+ GGML_ASSERT(src1->ne[0] == nc);
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+ const float alpha = ggml_get_op_params_f32(dst, 2);
+ const float limit = ggml_get_op_params_f32(dst, 3);
+
+ float * src0_p = (float *) src0_d;
+ float * src1_p = (float *) src1_d;
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
+}
+
+/* CUDA kernel + launcher for xIELU */
+
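+// xIELU is piecewise: for x > 0 it returns alpha_p*x^2 + beta*x, for x <= 0 it returns alpha_n*(expm1(min(x, eps)) - x) + beta*x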
+template <typename T>
+static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ const float xi = ggml_cuda_cast<float>(x[i]);
+
+ const float gate_pos = (xi > 0.0f);
+ const float y_pos = alpha_p * xi * xi + beta * xi;
+ const float min_v_eps = fminf(xi, eps);
+ const float y_neg = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi;
+ const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg;
+
+ dst[i] = ggml_cuda_cast<T>(out);
+}
+
+template <typename T>
+static void xielu_cuda(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_XIELU_BLOCK_SIZE - 1) / CUDA_XIELU_BLOCK_SIZE;
+ xielu_kernel<<<num_blocks, CUDA_XIELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, alpha_n, alpha_p, beta, eps);
+}
+
+void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const void * src0_d = src0->data;
+ void * dst_d = dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+ GGML_ASSERT(src0->type == dst->type);
+
+ const float alpha_n = ggml_get_op_params_f32(dst, 1);
+ const float alpha_p = ggml_get_op_params_f32(dst, 2);
+ const float beta = ggml_get_op_params_f32(dst, 3);
+ const float eps = ggml_get_op_params_f32(dst, 4);
+
+ if (src0->type == GGML_TYPE_F16) {
+ xielu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream);
+ } else {
+ xielu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream);
+ }
+}
+
+/* silu_back */
+
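+// derivative of silu(x) = x*sigmoid(x): silu'(x) = s*(1 + x*(1 - s)) with s = sigmoid(x), multiplied by the incoming gradient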
+static __device__ __forceinline__ float op_silu_back(float grad, float x) {
+ const float s = 1.0f / (1.0f + expf(-x));
+ return grad * s * (1.0f + x * (1.0f - s));
+}
+
+template <class T>
+static __global__ void silu_back_kernel(const T * grad, const T * xf, T * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ dst[i] = (T)op_silu_back((float)grad[i], (float)xf[i]);
+}
+
+template <class T>
+static void silu_back_cuda(const T * grad, const T * x, T * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BACK_BLOCK_SIZE;
+ silu_back_kernel<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
+}
+
+void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0]; // input from forward pass
+ const ggml_tensor * src1 = dst->src[1]; // grads of forward pass output
+
+ const float * src0_d = (const float *) src0->data;
+ const float * src1_d = (const float *) src1->data;
+ float * dst_d = (float *) dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+ GGML_ASSERT(src0->type == dst->type);
+
+ if (src0->type == GGML_TYPE_F16) {
+ silu_back_cuda((const half *)src0_d, (const half *)src1_d, (half *)dst_d, ggml_nelements(src0), stream);
+ } else {
+ silu_back_cuda((const float*)src0_d, (const float*)src1_d, (float *)dst_d, ggml_nelements(src0), stream);
+ }
+}
+
+/* leaky relu */
+
+static __device__ __forceinline__ float op_leaky_relu(float x, const float negative_slope) {
+ return fmaxf(x, 0.0f) + fminf(x, 0.0f) * negative_slope;
+}
+
+template <class T>
+static __global__ void leaky_relu_kernel(const T * x, T * dst, const int k, const float negative_slope) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ dst[i] = (T)op_leaky_relu((float)x[i], negative_slope);
+}
+
+template <class T>
+static void leaky_relu_cuda(const T * x, T * dst, const int k, const float negative_slope, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+ leaky_relu_kernel<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
+}
+
+void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const void * src0_d = src0->data;
+ void * dst_d = dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+ GGML_ASSERT(src0->type == dst->type);
+
+ float negative_slope;
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+ if (src0->type == GGML_TYPE_F16) {
+ leaky_relu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), negative_slope, stream);
+ } else {
+ leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream);
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/unary.cuh b/llama.cpp/ggml/src/ggml-cuda/unary.cuh
new file mode 100644
index 0000000..609046e
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/unary.cuh
@@ -0,0 +1,110 @@
+#pragma once
+#include "common.cuh"
+
+#define CUDA_NEG_BLOCK_SIZE 256
+#define CUDA_STEP_BLOCK_SIZE 256
+#define CUDA_GELU_BLOCK_SIZE 256
+#define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_SILU_BACK_BLOCK_SIZE 256
+#define CUDA_TANH_BLOCK_SIZE 256
+#define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_SIGMOID_BLOCK_SIZE 256
+#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
+#define CUDA_EXP_BLOCK_SIZE 256
+#define CUDA_HARDSWISH_BLOCK_SIZE 256
+#define CUDA_SQR_BLOCK_SIZE 256
+#define CUDA_SQRT_BLOCK_SIZE 256
+#define CUDA_SIN_BLOCK_SIZE 256
+#define CUDA_COS_BLOCK_SIZE 256
+#define CUDA_GLU_BLOCK_SIZE 256
+#define CUDA_XIELU_BLOCK_SIZE 256
+
+void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
+ return x / (1.0f + expf(-x));
+}
+
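+// tanh-based approximation of GELU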
+__device__ __forceinline__ float ggml_cuda_op_gelu_single(float x) {
+ const float GELU_COEF_A = 0.044715f;
+ const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+ return 0.5f * x * (1.0f + tanhf(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));
+}
+
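+// clamped SwiGLU: x is clamped from above to limit, g to [-limit, limit]; the sigmoid gate uses slope alpha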
+__device__ __forceinline__ float ggml_cuda_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) {
+ x = fminf(x, limit);
+ g = fmaxf(fminf(g, limit), -limit);
+
+ float out_glu = x / (1.0f + expf(-x * alpha));
+ out_glu = out_glu * (1.0f + g);
+ return out_glu;
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/upscale.cu b/llama.cpp/ggml/src/ggml-cuda/upscale.cu
new file mode 100644
index 0000000..6bdf3cd
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/upscale.cu
@@ -0,0 +1,293 @@
+#include "upscale.cuh"
+
+static __global__ void upscale_f32(const float * x, float * dst,
+ const int nb00, const int nb01, const int nb02, const int nb03,
+ const int ne10, const int ne11, const int ne12, const int ne13,
+ const float sf0, const float sf1, const float sf2, const float sf3) {
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index >= ne10 * ne11 * ne12 * ne13) {
+ return;
+ }
+
+ int i10 = index % ne10;
+ int i11 = (index / ne10) % ne11;
+ int i12 = (index / (ne10 * ne11)) % ne12;
+ int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
+
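+ // nearest neighbor: map each dst coordinate back to a src coordinate by truncating the scaled index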
+ int i00 = i10 / sf0;
+ int i01 = i11 / sf1;
+ int i02 = i12 / sf2;
+ int i03 = i13 / sf3;
+
+ dst[index] = *( (const float *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00) );
+}
+
+static __global__ void upscale_f32_bilinear(const float * x, float * dst,
+ const int nb00, const int nb01, const int nb02, const int nb03,
+ const int ne00_src, const int ne01_src,
+ const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
+ const float sf0, const float sf1, const float sf2, const float sf3,
+ const float pixel_offset) {
+ const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
+ const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
+
+ if (index >= dst_total_elements) {
+ return;
+ }
+
+ const int i10_dst = index % ne10_dst;
+ const int i11_dst = (index / ne10_dst) % ne11_dst;
+ const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
+ const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
+
+ const int i02_src = (int)(i12_dst / sf2);
+ const int i03_src = (int)(i13_dst / sf3);
+
+ const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
+ int y0_src = (int)floorf(y_src_f);
+ int y1_src = y0_src + 1;
+
+ y0_src = max(0, min(y0_src, ne01_src - 1));
+ y1_src = max(0, min(y1_src, ne01_src - 1));
+
+ float dy = y_src_f - (float)y0_src;
+ dy = max(0.0f, min(dy, 1.0f));
+
+ float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
+ int x0_src = (int)floorf(x_src_f);
+ int x1_src = x0_src + 1;
+
+ x0_src = max(0, min(x0_src, ne00_src - 1));
+ x1_src = max(0, min(x1_src, ne00_src - 1));
+
+ float dx = x_src_f - (float)x0_src;
+ dx = max(0.0f, min(dx, 1.0f));
+
+ const float * p_a = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
+ const float * p_b = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
+ const float * p_c = (const float *)((const char *)x + (int64_t)x0_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
+ const float * p_d = (const float *)((const char *)x + (int64_t)x1_src * nb00 + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03);
+
+ const float val_a = *p_a;
+ const float val_b = *p_b;
+ const float val_c = *p_c;
+ const float val_d = *p_d;
+
+ float result = val_a * (1.0f - dx) * (1.0f - dy) +
+ val_b * dx * (1.0f - dy) +
+ val_c * (1.0f - dx) * dy +
+ val_d * dx * dy;
+
+ dst[index] = result;
+}
+
+// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
+// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
+static __global__ void upscale_f32_bilinear_antialias(const float * src0, float * dst,
+ const int nb00, const int nb01, const int nb02, const int nb03,
+ const int ne00_src, const int ne01_src,
+ const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
+ const float sf0, const float sf1, const float sf2, const float sf3,
+ const float pixel_offset) {
+ const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
+ const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
+
+ if (index >= dst_total_elements) {
+ return;
+ }
+
+ const int i10_dst = index % ne10_dst;
+ const int i11_dst = (index / ne10_dst) % ne11_dst;
+ const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
+ const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
+
+ const int i02_src = (int)(i12_dst / sf2);
+ const int i03_src = (int)(i13_dst / sf3);
+
+ const float y = ((float)i11_dst + pixel_offset) / sf1;
+ const float x = ((float)i10_dst + pixel_offset) / sf0;
+
+ // support and invscale, minimum 1 pixel for bilinear
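+ // when downscaling (sf < 1) the support widens to 1/sf source pixels so all covered pixels get averaged (antialiasing);
+ // when upscaling the support stays at 1 pixel and the triangle filter reduces to ordinary bilinear weights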
+ const float support1 = max(1.0f / sf1, 1.0f);
+ const float invscale1 = 1.0f / support1;
+ const float support0 = max(1.0f / sf0, 1.0f);
+ const float invscale0 = 1.0f / support0;
+
+ // the range of source pixels that contribute
+ const int64_t x_min = max(int64_t(0), int64_t(x - support0 + pixel_offset));
+ const int64_t x_max = min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset));
+ const int64_t y_min = max(int64_t(0), int64_t(y - support1 + pixel_offset));
+ const int64_t y_max = min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset));
+
+ // bilinear filter with antialiasing
+ float val = 0.0f;
+ float total_weight = 0.0f;
+
+ auto triangle_filter = [](float x) -> float {
+ return max(1.0f - fabsf(x), 0.0f);
+ };
+
+ for (int64_t sy = y_min; sy < y_max; sy++) {
+ const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
+
+ for (int64_t sx = x_min; sx < x_max; sx++) {
+ const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
+ const float weight = weight_x * weight_y;
+
+ if (weight <= 0.0f) {
+ continue;
+ }
+
+ const float pixel = *(const float *)((const char *)src0 + sx*nb00 + sy*nb01 + i02_src*nb02 + i03_src*nb03);
+ val += pixel * weight;
+ total_weight += weight;
+ }
+ }
+
+ if (total_weight > 0.0f) {
+ val /= total_weight;
+ }
+
+ dst[index] = val;
+}
+
+namespace bicubic_interpolation {
+// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
+__device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
+
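+// convolution kernel: weight1(x) covers |x| <= 1 with (a+2)|x|^3 - (a+3)|x|^2 + 1, weight2(x) covers 1 < |x| < 2 with a|x|^3 - 5a|x|^2 + 8a|x| - 4a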
+static __device__ float weight1(float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
+static __device__ float weight2(float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
+
+static __device__ float bicubic(float p0, float p1, float p2, float p3, float x) {
+ const float w0 = weight2(x + 1);
+ const float w1 = weight1(x + 0);
+ const float w2 = weight1(1 - x);
+ const float w3 = weight2(2 - x);
+ return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3;
+};
+} // namespace bicubic_interpolation
+
+static __global__ void upscale_f32_bicubic(const float * x, float * dst,
+ const int nb00, const int nb01, const int nb02, const int nb03,
+ const int ne00_src, const int ne01_src,
+ const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
+ const float sf0, const float sf1, const float sf2, const float sf3,
+ const float pixel_offset) {
+ using bicubic_interpolation::bicubic;
+
+ const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
+ const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
+
+ if (index >= dst_total_elements) {
+ return;
+ }
+
+ const int i10_dst = index % ne10_dst;
+ const int i11_dst = (index / ne10_dst) % ne11_dst;
+ const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
+ const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
+
+ const int i02_src = (int)(i12_dst / sf2);
+ const int i03_src = (int)(i13_dst / sf3);
+
+ const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
+ const int y0_src = (int)floorf(y_src_f);
+ const float dy = y_src_f - (float)y0_src;
+
+ const float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
+ const int x0_src = (int)floorf(x_src_f);
+ const float dx = x_src_f - (float)x0_src;
+
+ const char * x_base = (const char *)x + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03;
+
+ auto load = [=](int x_off, int y_off) -> float {
+ int i00_src = max(0, min(x0_src + x_off, ne00_src - 1));
+ int i01_src = max(0, min(y0_src + y_off, ne01_src - 1));
+ return *(const float *)(x_base + (int64_t)i00_src * nb00 + (int64_t)i01_src * nb01);
+ };
+
+ const float result = bicubic(
+ bicubic(load(-1,-1), load(0,-1), load(1,-1), load(2,-1), dx),
+ bicubic(load(-1, 0), load(0, 0), load(1, 0), load(2, 0), dx),
+ bicubic(load(-1, 1), load(0, 1), load(1, 1), load(2, 1), dx),
+ bicubic(load(-1, 2), load(0, 2), load(1, 2), load(2, 2), dx), dy);
+
+ dst[index] = result;
+}
+
+static void upscale_f32_cuda(const float * x, float * dst,
+ const int nb00, const int nb01, const int nb02, const int nb03,
+ const int ne10, const int ne11, const int ne12, const int ne13,
+ const float sf0, const float sf1, const float sf2, const float sf3,
+ cudaStream_t stream) {
+ const int64_t dst_size = ne10 * ne11 * ne12 * ne13;
+ const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
+
+ upscale_f32<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3);
+}
+
+static void upscale_f32_bilinear_cuda(const float * x, float * dst,
+ const int nb00, const int nb01, const int nb02, const int nb03,
+ const int ne00_src, const int ne01_src,
+ const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
+ const float sf0, const float sf1, const float sf2, const float sf3,
+ const float pixel_offset, bool antialias, cudaStream_t stream) {
+ const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
+ const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
+
+ if (antialias) {
+ upscale_f32_bilinear_antialias<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
+ } else {
+ upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
+ }
+}
+
+static void upscale_f32_bicubic_cuda(const float * x, float * dst,
+ const int nb00, const int nb01, const int nb02, const int nb03,
+ const int ne00_src, const int ne01_src,
+ const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
+ const float sf0, const float sf1, const float sf2, const float sf3,
+ const float pixel_offset, cudaStream_t stream) {
+ const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
+ const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
+
+ upscale_f32_bicubic<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
+}
+
+void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ const int mode_flags = dst->op_params[0];
+ const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF);
+
+ float sf0 = (float)dst->ne[0]/src0->ne[0];
+ float sf1 = (float)dst->ne[1]/src0->ne[1];
+ float sf2 = (float)dst->ne[2]/src0->ne[2];
+ const float sf3 = (float)dst->ne[3]/src0->ne[3];
+
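+ // half-pixel coordinate mapping: src = (dst + 0.5)/sf - 0.5; with align corners the offset becomes 0
+ // and the scale factors are recomputed as (ne_dst - 1)/(ne_src - 1) so the corner pixels map exactly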
+ float pixel_offset = 0.5f;
+ if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
+ sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0;
+ sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1;
+ pixel_offset = 0.0f;
+ }
+
+ if (mode == GGML_SCALE_MODE_NEAREST) {
+ upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
+ } else if (mode == GGML_SCALE_MODE_BILINEAR) {
+ const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS);
+ upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+ sf0, sf1, sf2, sf3, pixel_offset, antialias, stream);
+ } else if (mode == GGML_SCALE_MODE_BICUBIC) {
+ upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+ src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
+ sf0, sf1, sf2, sf3, pixel_offset, stream);
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/upscale.cuh b/llama.cpp/ggml/src/ggml-cuda/upscale.cuh
new file mode 100644
index 0000000..d4d7652
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/upscale.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_UPSCALE_BLOCK_SIZE 256
+
+void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh b/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh
new file mode 100644
index 0000000..6baab11
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/vecdotq.cuh
@@ -0,0 +1,1223 @@
+#pragma once
+
+#include "common.cuh"
+
+#include <cstdint>
+
+static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
+ const uint8_t * x8 = (const uint8_t *) x;
+
+ int x32 = x8[4*i32 + 0] << 0;
+ x32 |= x8[4*i32 + 1] << 8;
+ x32 |= x8[4*i32 + 2] << 16;
+ x32 |= x8[4*i32 + 3] << 24;
+
+ return x32;
+}
+
+static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
+ const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
+
+ int x32 = x16[2*i32 + 0] << 0;
+ x32 |= x16[2*i32 + 1] << 16;
+
+ return x32;
+}
+
+static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32) {
+ return ((const int *) x)[i32]; // assume at least 4 byte alignment
+}
+
+// q4 contains 8 indices with 4 bit each.
+// This function selects those bytes from table that are at those indices and returns them as int2.
+// The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4.
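+// Example: q4 = 0x75318642 yields .x = {table[2], table[6], table[1], table[5]} and .y = {table[4], table[8], table[3], table[7]} (bytes from least to most significant).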
+static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
+#if defined(GGML_USE_HIP)
+ // Load the 16-byte table into four 32-bit unsigned integers.
+ const uint32_t *values = (const uint32_t *)table;
+
+ const uint32_t q_even = q4;
+ const uint32_t q_odd = (q4 >> 4);
+
+ // Perform lookups in the lower half of the table (indices 0-7).
+ uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);
+ uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);
+
+ // Perform lookups in the upper half of the table (indices 8-15).
+ uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);
+ uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);
+
+ // Select between the low and high results based on the MSB of each index nibble.
+ uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
+ uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);
+ uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);
+ uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);
+
+ return make_int2(res_x, res_y);
+#elif !defined(GGML_USE_MUSA)
+ // CUDA does not have an instruction for selecting bytes with 4 bit indices.
+ // However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead.
+ const uint32_t * table32 = (const uint32_t *) table;
+
+ // __byte_perm selects bytes based on the lower 16 bits in its third argument.
+ // Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift.
+ // To handle the fourth bit, first call __byte_perm both for the low and the high 64 bit of table, using the low 3 bits.
+ // Then, call __byte_perm again to select from the low and high bytes based on the fourth bit.
+ uint32_t tmp[2];
+ const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));
+#pragma unroll
+ for (uint32_t i = 0; i < 2; ++i) {
+ const uint32_t shift = 16 * i;
+
+ const uint32_t low = __byte_perm(table32[0], table32[1], q4 >> shift);
+ const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
+ tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);
+ }
+
+ // tmp contains the bytes from table in the same order as the 4 bit indices in q4.
+ // However, for the result we need ints with all even/odd 4 bit indices in q4.
+ // Therefore, 2 more calls to __byte_perm to put the bytes in the correct order.
+ return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
+#else
+ // Generic implementation.
+ const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
+ const int8_t * q0_8 = (const int8_t *) &q0_32;
+ const char4 val0_8 = make_char4(
+ table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]);
+
+ const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
+ const int8_t * q1_8 = (const int8_t *) &q1_32;
+ const char4 val1_8 = make_char4(
+ table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
+
+ return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
+#endif
+}
+
+// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
+// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
+
+#define VDR_Q4_0_Q8_1_MMVQ 2
+#define VDR_Q4_0_Q8_1_MMQ 4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl(
+ const int * v, const int * u, const float & d4, const half2 & ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+ const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+ // SIMD dot product of quantized values
+ sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
+ sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
+ }
+
+ const float2 ds8f = __half22float2(ds8);
+
+ // second part effectively subtracts 8 from each quant value
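+ // ds8 holds (d8, d8 * sum of the 32 q8 values); the full dot product d4*d8*sum((q4 - 8)*q8) needs a -8*d8*sum(q8)
+ // correction, which is split evenly across the QI4_0/vdr threads that process this block (hence the 8*vdr/QI4_0 factor)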
+ return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
+}
+
+#define VDR_Q4_1_Q8_1_MMVQ 2
+#define VDR_Q4_1_Q8_1_MMQ 4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl(
+ const int * v, const int * u, const half2 & dm4, const half2 & ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ const int vi0 = (v[i] >> 0) & 0x0F0F0F0F;
+ const int vi1 = (v[i] >> 4) & 0x0F0F0F0F;
+
+ // SIMD dot product of quantized values
+ sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi);
+ sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi);
+ }
+
+#ifdef FAST_FP16_AVAILABLE
+ const float2 tmp = __half22float2(__hmul2(dm4, ds8));
+ const float d4d8 = tmp.x;
+ const float m4s8 = tmp.y;
+#else
+ const float2 dm4f = __half22float2(dm4);
+ const float2 ds8f = __half22float2(ds8);
+ const float d4d8 = dm4f.x * ds8f.x;
+ const float m4s8 = dm4f.y * ds8f.y;
+#endif // FAST_FP16_AVAILABLE
+
+ // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
+ return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
+}
+
+#define VDR_Q5_0_Q8_1_MMVQ 2
+#define VDR_Q5_0_Q8_1_MMQ 4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl(
+ const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+ vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4
+ vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+ vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+ vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+ sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
+
+ int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+ vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
+ vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
+ vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
+ vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
+ sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
+ }
+
+ const float2 ds8f = __half22float2(ds8);
+
+ // second part effectively subtracts 16 from each quant value
+ return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
+}
+
+#define VDR_Q5_1_Q8_1_MMVQ 2
+#define VDR_Q5_1_Q8_1_MMQ 4
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl(
+ const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
+ vi0 |= (vh[i] << 4) & 0x00000010; // 0 -> 4
+ vi0 |= (vh[i] << 11) & 0x00001000; // 1 -> 12
+ vi0 |= (vh[i] << 18) & 0x00100000; // 2 -> 20
+ vi0 |= (vh[i] << 25) & 0x10000000; // 3 -> 28
+ sumi = ggml_cuda_dp4a(vi0, u[2*i+0], sumi); // SIMD dot product of quantized values
+
+ int vi1 = (vl[i] >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh as 5th bits
+ vi1 |= (vh[i] >> 12) & 0x00000010; // 16 -> 4
+ vi1 |= (vh[i] >> 5) & 0x00001000; // 17 -> 12
+ vi1 |= (vh[i] << 2) & 0x00100000; // 18 -> 20
+ vi1 |= (vh[i] << 9) & 0x10000000; // 19 -> 28
+ sumi = ggml_cuda_dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
+ }
+
+#ifdef FAST_FP16_AVAILABLE
+ const float2 tmp = __half22float2(__hmul2(dm5, ds8));
+ const float d5d8 = tmp.x;
+ const float m5s8 = tmp.y;
+#else
+ const float2 dm5f = __half22float2(dm5);
+ const float2 ds8f = __half22float2(ds8);
+ const float d5d8 = dm5f.x * ds8f.x;
+ const float m5s8 = dm5f.y * ds8f.y;
+#endif // FAST_FP16_AVAILABLE
+
+ // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
+ return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
+}
+
+#define VDR_Q8_0_Q8_1_MMVQ 2
+#define VDR_Q8_0_Q8_1_MMQ 8
+
+template <typename T, int vdr> static __device__ __forceinline__ T vec_dot_q8_0_q8_1_impl(
+ const int * v, const int * u, const T & d8_0, const T & d8_1) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ // SIMD dot product of quantized values
+ sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
+ }
+
+ return d8_0*d8_1 * ((T) sumi);
+}
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl(
+ const int * v, const int * u, const half2 & dm8, const half2 & ds8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ // SIMD dot product of quantized values
+ sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
+ }
+
+#ifdef FAST_FP16_AVAILABLE
+ const float2 tmp = __half22float2(__hmul2(dm8, ds8));
+ const float d8d8 = tmp.x;
+ const float m8s8 = tmp.y;
+#else
+ const float2 dm8f = __half22float2(dm8);
+ const float2 ds8f = __half22float2(ds8);
+ const float d8d8 = dm8f.x * ds8f.x;
+ const float m8s8 = dm8f.y * ds8f.y;
+#endif // FAST_FP16_AVAILABLE
+
+ // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
+ return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
+}
+
+template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_impl(
+ const int * v, const int * u, const float * d8_0, const float & d8_1) {
+
+ float sumf = 0.0f;
+
+#pragma unroll
+ for (int i0 = 0; i0 < vdr; i0 += QI8_0/2) {
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = i0; i < i0 + QI8_0/2; ++i) {
+ // SIMD dot product of quantized values
+ sumi = ggml_cuda_dp4a(v[i], u[i], sumi);
+ }
+
+ sumf += d8_0[i0/(QI8_0/2)]*sumi;
+ }
+
+ return d8_1*sumf;
+}
+
+#define VDR_MXFP4_Q8_1_MMVQ 2
+#define VDR_MXFP4_Q8_1_MMQ 4
+
+static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;
+
+ const int * q8 = (const int *) bq8_1->qs + iqs;
+
+ int sumi = 0;
+#pragma unroll
+ for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
+ const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+
+ sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
+ sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
+ }
+
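+ // kvalues_mxfp4 stores the e2m1 quant values doubled so they stay integral, hence the 0.5f; bq4->e is the shared E8M0 block exponent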
+ const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds);
+ return d * sumi;
+}
+
+#define VDR_Q2_K_Q8_1_MMVQ 1
+#define VDR_Q2_K_Q8_1_MMQ 4
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
+ const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+ const half2 & dm2, const float * __restrict__ d8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR2_K; ++i) {
+ const int sc = scales[2*i];
+
+ const int vi = (v >> (2*i)) & 0x03030303;
+
+ sumf_d += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
+
+ // fill int with 4x m
+ int m = sc >> 4;
+ m |= m << 8;
+ m |= m << 16;
+ sumf_m += d8[i] * ggml_cuda_dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
+ }
+
+ const float2 dm2f = __half22float2(dm2);
+
+ return dm2f.x*sumf_d - dm2f.y*sumf_m;
+}
+
+// contiguous v/x + u/y values
+template <int ns8>
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
+ const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {
+
+ float sumf = 0.0f;
+ float sumf_d8 = 0.0f;
+
+#pragma unroll
+ for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {
+ const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);
+ int sumi_d0 = 0;
+
+ const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);
+ int sumi_d1 = 0;
+
+#pragma unroll
+ for (int i = i0; i < i0 + QI8_1/2; ++i) {
+ sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);
+ }
+ sumf_d8 += dm2f0.x * sumi_d0;
+
+#pragma unroll
+ for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
+ sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);
+ }
+ sumf_d8 += dm2f1.x * sumi_d1;
+
+ if (i0/QI8_1 < ns8) {
+ const float2 s8f = __half22float2(s8[i0/QI8_1]);
+ sumf -= dm2f0.y*s8f.x;
+ sumf -= dm2f1.y*s8f.y;
+ } else {
+ int sumi_m0 = 0;
+#pragma unroll
+ for (int i = i0; i < i0 + QI8_1/2; ++i) {
+ sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);
+ }
+ sumf_d8 -= dm2f0.y * sumi_m0;
+
+ int sumi_m1 = 0;
+#pragma unroll
+ for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
+ sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);
+ }
+ sumf_d8 -= dm2f1.y * sumi_m1;
+ }
+ }
+
+ return sumf + d8*sumf_d8;
+}
+
+#define VDR_Q3_K_Q8_1_MMVQ 1
+#define VDR_Q3_K_Q8_1_MMQ 2
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
+ const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+ const int & scale_offset, const float & d3, const float * __restrict__ d8) {
+
+ float sumf = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR3_K; ++i) {
+ const int isc = scale_offset + 2*i;
+
+ const int isc_low = isc % (QK_K/32);
+ const int sc_shift_low = 4 * (isc / (QK_K/32));
+ const int sc_low = (scales[isc_low] >> sc_shift_low) & 0xF;
+
+ const int isc_high = isc % (QK_K/64);
+ const int sc_shift_high = 2 * (isc / (QK_K/64));
+ const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
+
+ const int sc = (sc_low | sc_high) - 32;
+
+ const int vil = (vl >> (2*i)) & 0x03030303;
+
+ const int vih = ((vh >> i) << 2) & 0x04040404;
+
+ const int vi = __vsubss4(vil, vih);
+
+ sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
+ }
+
+ return d3 * sumf;
+}
+
+// contiguous v/x + u/y values
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
+ const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
+ const float & d3, const float & d8) {
+
+ int sumi = 0;
+
+#pragma unroll
+ for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
+ int sumi_sc = 0;
+
+#pragma unroll
+ for (int i = i0; i < i0 + QI8_1/2; ++i) {
+ sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
+ }
+
+ sumi += sumi_sc * scales[i0 / (QI8_1/2)];
+ }
+
+ return d3*d8 * sumi;
+}
+
+#define VDR_Q4_K_Q8_1_MMVQ 2
+#define VDR_Q4_K_Q8_1_MMQ 8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
+ const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+ const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR4_K; ++i) {
+ const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
+ const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
+
+ const int dot1 = ggml_cuda_dp4a(v1i, u[2*i+1], ggml_cuda_dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
+ const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+1], ggml_cuda_dp4a(0x01010101, u[2*i+0], 0)); // sum of u
+
+ sumf_d += d8[i] * (dot1 * sc[i]);
+ sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
+ }
+
+ const float2 dm4f = __half22float2(dm4);
+
+ return dm4f.x*sumf_d - dm4f.y*sumf_m;
+}
+
+// contiguous v/x + u/y values
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
+ const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+ const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
+ int sumi_d = 0;
+
+#pragma unroll
+ for (int j = 0; j < QI8_1; ++j) {
+ sumi_d = ggml_cuda_dp4a((v[j] >> (4*i)) & 0x0F0F0F0F, u[i*QI8_1 + j], sumi_d); // SIMD dot product
+ }
+
+ const float2 ds8f = __half22float2(ds8[i]);
+
+ sumf_d += ds8f.x * (sc[i] * sumi_d);
+ sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q4_K min val
+ }
+
+ const float2 dm4f = __half22float2(dm4);
+
+ return dm4f.x*sumf_d - dm4f.y*sumf_m;
+}
+
+#define VDR_Q5_K_Q8_1_MMVQ 2
+#define VDR_Q5_K_Q8_1_MMQ 8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
+ const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+ const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR5_K; ++i) {
+ const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
+ const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;
+
+ const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
+ const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;
+
+ const int v0i = vl0i | vh0i;
+ const int v1i = vl1i | vh1i;
+
+ const int dot1 = ggml_cuda_dp4a(v0i, u[2*i+0], ggml_cuda_dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
+ const int dot2 = ggml_cuda_dp4a(0x01010101, u[2*i+0], ggml_cuda_dp4a(0x01010101, u[2*i+1], 0)); // sum of u
+
+ sumf_d += d8[i] * (dot1 * sc[i]);
+ sumf_m += d8[i] * (dot2 * m[i]);
+
+ }
+
+ const float2 dm5f = __half22float2(dm5);
+
+ return dm5f.x*sumf_d - dm5f.y*sumf_m;
+}
+
+// contiguous v/x + u/y values
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
+ const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+ const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
+
+ float sumf_d = 0.0f;
+ float sumf_m = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
+ int sumi_d = 0;
+
+#pragma unroll
+ for (int j = 0; j < QI8_1; ++j) {
+ sumi_d = ggml_cuda_dp4a(v[i*QI8_1 + j], u[i*QI8_1 + j], sumi_d); // SIMD dot product
+ }
+
+ const float2 ds8f = __half22float2(ds8[i]);
+
+ sumf_d += ds8f.x * (sc[i] * sumi_d);
+ sumf_m += ds8f.y * m[i]; // sum of q8_1 block * q5_K min val
+ }
+
+ const float2 dm4f = __half22float2(dm4);
+
+ return dm4f.x*sumf_d - dm4f.y*sumf_m;
+}
+
+#define VDR_Q6_K_Q8_1_MMVQ 1
+#define VDR_Q6_K_Q8_1_MMQ 8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
+ const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,
+ const float & d, const float * __restrict__ d8) {
+
+ float sumf = 0.0f;
+
+#pragma unroll
+ for (int i = 0; i < QR6_K; ++i) {
+ const int sc = scales[4*i];
+
+ const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+
+ const int vih = ((vh >> (4*i)) << 4) & 0x30303030;
+
+ const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
+
+ sumf += d8[i] * (ggml_cuda_dp4a(vi, u[i], 0) * sc); // SIMD dot product
+ }
+
+ return d*sumf;
+}
+
+// contiguous v/x + u/y values
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
+ const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
+ const float & d6, const float * __restrict__ d8) {
+
+ float sumf_d = 0.0f;
+
+ const int sc_packed = get_int_b4(sc, 0);
+ const int8_t * sc_reg = (const int8_t *) &sc_packed;
+
+#pragma unroll
+ for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
+ int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
+
+#pragma unroll
+ for (int i = i0; i < i0 + 2; ++i) {
+ sumi_d.x = ggml_cuda_dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
+ sumi_d.x = ggml_cuda_dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product
+
+ sumi_d.y = ggml_cuda_dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
+ sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
+ }
+
+ sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);
+ }
+
+ return d6 * sumf_d;
+}
+
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq + kbx;
+
+ int v[VDR_Q4_0_Q8_1_MMVQ];
+ int u[2*VDR_Q4_0_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q4_0_Q8_1_MMVQ; ++i) {
+ v[i] = get_int_b2(bq4_0->qs, iqs + i);
+ u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI4_0);
+ }
+
+ return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
+}
+
+
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq + kbx;
+
+ int v[VDR_Q4_1_Q8_1_MMVQ];
+ int u[2*VDR_Q4_1_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q4_1_Q8_1_MMVQ; ++i) {
+ v[i] = get_int_b4(bq4_1->qs, iqs + i);
+ u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI4_1);
+ }
+
+ return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq + kbx;
+
+ int vl[VDR_Q5_0_Q8_1_MMVQ];
+ int vh[VDR_Q5_0_Q8_1_MMVQ];
+ int u[2*VDR_Q5_0_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q5_0_Q8_1_MMVQ; ++i) {
+ vl[i] = get_int_b2(bq5_0->qs, iqs + i);
+ vh[i] = get_int_b2(bq5_0->qh, 0) >> (4 * (iqs + i));
+ u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI5_0);
+ }
+
+ return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq + kbx;
+
+ int vl[VDR_Q5_1_Q8_1_MMVQ];
+ int vh[VDR_Q5_1_Q8_1_MMVQ];
+ int u[2*VDR_Q5_1_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q5_1_Q8_1_MMVQ; ++i) {
+ vl[i] = get_int_b4(bq5_1->qs, iqs + i);
+ vh[i] = get_int_b4(bq5_1->qh, 0) >> (4 * (iqs + i));
+ u[2*i+0] = get_int_b4(bq8_1->qs, iqs + i);
+ u[2*i+1] = get_int_b4(bq8_1->qs, iqs + i + QI5_1);
+ }
+
+ return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
+}
+
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq + kbx;
+
+ int v[VDR_Q8_0_Q8_1_MMVQ];
+ int u[VDR_Q8_0_Q8_1_MMVQ];
+
+#pragma unroll
+ for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
+ v[i] = get_int_b2(bq8_0->qs, iqs + i);
+ u[i] = get_int_b4(bq8_1->qs, iqs + i);
+ }
+
+ return vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
+}
+
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q2_K * bq2_K = (const block_q2_K *) vbq + kbx;
+
+ const int bq8_offset = QR2_K * (iqs / QI8_1);
+ const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+ const uint8_t * scales = bq2_K->scales + scale_offset;
+
+ const int v = get_int_b4(bq2_K->qs, iqs);
+ int u[QR2_K];
+ float d8[QR2_K];
+
+#pragma unroll
+ for (int i = 0; i < QR2_K; ++ i) {
+ u[i] = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+ d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
+ }
+
+ return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
+}
+
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q3_K * bq3_K = (const block_q3_K *) vbq + kbx;
+
+ const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
+ const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
+
+ const float d = bq3_K->d;
+
+ const int vl = get_int_b2(bq3_K->qs, iqs);
+
+ // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+ const int vh = ~get_int_b2(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;
+
+ int u[QR3_K];
+ float d8[QR3_K];
+
+#pragma unroll
+ for (int i = 0; i < QR3_K; ++i) {
+ u[i] = get_int_b4(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
+ d8[i] = __low2float(bq8_1[bq8_offset + i].ds);
+ }
+
+ return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
+}
+
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx;
+
+ int v[2];
+ int u[2*QR4_K];
+ float d8[QR4_K];
+
+ // iqs is in 0,2..30. bq8_offset = iqs/4 -> bq8_offset = 0, 2, 4, 6
+ const int bq8_offset = QR4_K * ((iqs/2) / (QI8_1/2));
+
+ // iqs = 0....3 -> bq8_offset = 0, want q4_offset = 0, 4, 8, 12
+ // iqs = 4....7 -> bq8_offset = 2, want q4_offset = 32, 36, 40, 44
+ // iqs = 8...11 -> bq8_offset = 4, want q4_offset = 64, 68, 72, 76
+ // iqs = 12..15 -> bq8_offset = 6, want q4_offset = 96, 100, 104, 108
+
+ const int * q4 = (const int *)(bq4_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+ v[0] = q4[0];
+ v[1] = q4[4];
+
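+ // the 12-byte scales field packs eight 6-bit scales and eight 6-bit mins; extract the scale/min pairs for the two sub-blocks touched here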
+ const uint16_t * scales = (const uint16_t *)bq4_K->scales;
+ uint16_t aux[2];
+ const int j = bq8_offset/2;
+ if (j < 2) {
+ aux[0] = scales[j+0] & 0x3f3f;
+ aux[1] = scales[j+2] & 0x3f3f;
+ } else {
+ aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+ aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+ }
+ const uint8_t * sc = (const uint8_t *)aux;
+ const uint8_t * m = sc + 2;
+
+ for (int i = 0; i < QR4_K; ++i) {
+ const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+ d8[i] = __low2float(bq8i->ds);
+
+ const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+ u[2*i+0] = q8[0];
+ u[2*i+1] = q8[4];
+ }
+
+ return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
+}
+
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q5_K * bq5_K = (const block_q5_K *) vbq + kbx;
+
+ int vl[2];
+ int vh[2];
+ int u[2*QR5_K];
+ float d8[QR5_K];
+
+ const int bq8_offset = QR5_K * ((iqs/2) / (QI8_1/2));
+ const int * ql = (const int *)(bq5_K->qs + 16 * bq8_offset + 4 * ((iqs/2)%4));
+ const int * qh = (const int *)(bq5_K->qh + 4 * ((iqs/2)%4));
+
+ vl[0] = ql[0];
+ vl[1] = ql[4];
+
+ vh[0] = qh[0] >> bq8_offset;
+ vh[1] = qh[4] >> bq8_offset;
+
+ const uint16_t * scales = (const uint16_t *)bq5_K->scales;
+ uint16_t aux[2];
+ const int j = bq8_offset/2;
+ if (j < 2) {
+ aux[0] = scales[j+0] & 0x3f3f;
+ aux[1] = scales[j+2] & 0x3f3f;
+ } else {
+ aux[0] = ((scales[j+2] >> 0) & 0x0f0f) | ((scales[j-2] & 0xc0c0) >> 2);
+ aux[1] = ((scales[j+2] >> 4) & 0x0f0f) | ((scales[j-0] & 0xc0c0) >> 2);
+ }
+ const uint8_t * sc = (const uint8_t *)aux;
+ const uint8_t * m = sc + 2;
+
+#pragma unroll
+ for (int i = 0; i < QR5_K; ++i) {
+ const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
+ d8[i] = __low2float(bq8i->ds);
+
+ const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
+ u[2*i+0] = q8[0];
+ u[2*i+1] = q8[4];
+ }
+
+ return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, bq5_K->dm, d8);
+}
+
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_q6_K * bq6_K = (const block_q6_K *) vbq + kbx;
+
+ const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
+ const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
+ const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));
+
+ const int vl = get_int_b2(bq6_K->ql, iqs);
+ const int vh = get_int_b2(bq6_K->qh, (QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4)) >> vh_shift;
+
+ const int8_t * scales = bq6_K->scales + scale_offset;
+
+ int u[QR6_K];
+ float d8[QR6_K];
+
+#pragma unroll
+ for (int i = 0; i < QR6_K; ++i) {
+ u[i] = get_int_b4(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
+ d8[i] = __low2float(bq8_1[bq8_offset + 2*i].ds);
+ }
+
+ return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
+}
+
+#define VDR_IQ2_XXS_Q8_1_MMVQ 2
+#define VDR_IQ2_XXS_Q8_1_MMQ 2
+
+static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq + kbx;
+
+ const int q2 = get_int_b2(bq2->qs, iqs);
+ const uint8_t * aux8 = (const uint8_t *) &q2;
+ const uint32_t aux32 = get_int_b2(bq2->qs, iqs + 1);
+
+ int sumi = 0;
+#pragma unroll
+ for (int k0 = 0; k0 < 8; k0 += 2) {
+ const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]);
+ const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F];
+
+ const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
+ const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
+ const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);
+ sumi = ggml_cuda_dp4a(grid0, u0, sumi);
+
+ const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
+ const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
+ const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);
+ sumi = ggml_cuda_dp4a(grid1, u1, sumi);
+ }
+
+ const int ls = aux32 >> 28;
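+ // the 4-bit sub-block scale in the top bits of aux32 is applied as (ls + 0.5)/4, computed here in integer arithmetic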
+ sumi = (ls*sumi + sumi/2)/4;
+ const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
+ return d * sumi;
+}
+
+#define VDR_IQ2_XS_Q8_1_MMVQ 2
+#define VDR_IQ2_XS_Q8_1_MMQ 2
+
+static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq + kbx;
+
+ const int2 q2_packed = make_int2(get_int_b2(bq2->qs, iqs + 0), get_int_b2(bq2->qs, iqs + 1));
+ const uint16_t * q2 = (const uint16_t *) &q2_packed;
+ const int ls0 = bq2->scales[iqs/2] & 0x0F;
+ const int ls1 = bq2->scales[iqs/2] >> 4;
+
+ int sumi0 = 0;
+ int sumi1 = 0;
+#pragma unroll
+ for (int l0 = 0; l0 < 8; l0 += 2) {
+ const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF));
+ const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9));
+
+ const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
+ const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
+
+ const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+ const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
+
+ if (l0 < 4) {
+ sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0);
+ sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0);
+ } else {
+ sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1);
+ sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1);
+ }
+ }
+ const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4;
+ const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
+ return d * sumi;
+}
+
+#define VDR_IQ2_S_Q8_1_MMVQ 2
+#define VDR_IQ2_S_Q8_1_MMQ 2
+
+static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq2_s * bq2 = (const block_iq2_s *) vbq + kbx;
+
+ const int qs_packed = get_int_b2(bq2->qs, iqs/2);
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ const int qh = bq2->qh[iqs/2];
+
+ const int signs_packed_32 = get_int_b2(bq2->qs, QK_K/32 + iqs/2);
+ const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
+
+ const int ls0 = bq2->scales[iqs/2] & 0x0F;
+ const int ls1 = bq2->scales[iqs/2] >> 4;
+
+ int sumi0 = 0;
+ int sumi1 = 0;
+#pragma unroll
+ for (int l0 = 0; l0 < 8; l0 += 2) {
+ const int * grid_pos = (const int *)(iq2s_grid + (qs[l0/2] | ((qh << (8-l0)) & 0x300)));
+
+ const int signs0 = __vcmpne4(((signs_packed_8[l0/2] & 0x03) << 7) | ((signs_packed_8[l0/2] & 0x0C) << 21), 0x00000000);
+ const int signs1 = __vcmpne4(((signs_packed_8[l0/2] & 0x30) << 3) | ((signs_packed_8[l0/2] & 0xC0) << 17), 0x00000000);
+
+ const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0);
+ const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1);
+
+ const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+ const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
+
+ if (l0 < 4) {
+ sumi0 = ggml_cuda_dp4a(grid_l, u0, sumi0);
+ sumi0 = ggml_cuda_dp4a(grid_h, u1, sumi0);
+ } else {
+ sumi1 = ggml_cuda_dp4a(grid_l, u0, sumi1);
+ sumi1 = ggml_cuda_dp4a(grid_h, u1, sumi1);
+ }
+ }
+ const int sumi = (sumi0*ls0 + sumi1*ls1 + (sumi0 + sumi1)/2)/4;
+
+ const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
+ return d * sumi;
+}
+
+#define VDR_IQ3_XXS_Q8_1_MMVQ 2
+#define VDR_IQ3_XXS_Q8_1_MMQ 2
+
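+// iq3_xxs: 8-bit indices into iq3xxs_grid, with the signs (7 bits per group of 8 values)
+// and the 4-bit block scale packed into aux32.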
+static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq3_xxs * bq3 = (const block_iq3_xxs *) vbq + kbx;
+
+ const int2 q3_packed = make_int2(get_int_b2(bq3->qs, iqs), get_int_b2(bq3->qs, iqs+1));
+ const uint8_t * q3 = (const uint8_t *) &q3_packed;
+ const uint32_t aux32 = get_int_b2(bq3->qs, QK_K/16 + iqs/2);
+
+ int sumi = 0;
+#pragma unroll
+ for (int l0 = 0; l0 < 8; l0 += 2) {
+ const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]);
+
+ const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F));
+
+ const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
+ const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
+
+ const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+ const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
+
+ sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
+ sumi = ggml_cuda_dp4a(grid_h, u1, sumi);
+ }
+
+ const int ls = aux32 >> 28;
+ sumi = (ls*sumi + sumi/2)/2;
+ const float d = __half2float(bq3->d) * __low2float(bq8_1[iqs/2].ds);
+ return d * sumi;
+}
+
+#define VDR_IQ3_S_Q8_1_MMVQ 2
+#define VDR_IQ3_S_Q8_1_MMQ 2
+
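+// iq3_s: like iq3_xxs, but the 9th grid-index bit comes from qh, the signs are stored
+// per block, and the sub-block scale is decoded as 1 + 2*nibble.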
+// TODO: don't use lookup table for signs
+static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq3_s * bq3 = (const block_iq3_s *) vbq + kbx;
+
+ const int2 qs_packed = make_int2(get_int_b2(bq3->qs, iqs + 0), get_int_b2(bq3->qs, iqs + 1));
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ const int qh = bq3->qh[iqs/2];
+
+ const int signs_packed_32 = get_int_b2(bq3->signs, iqs/2);
+ const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32;
+
+ int sumi = 0;
+#pragma unroll
+ for (int l0 = 0; l0 < 8; l0 += 2) {
+ const int2 grid_pos = make_int2(
+ iq3s_grid[qs[l0 + 0] | ((qh << (8 - l0)) & 0x100)],
+ iq3s_grid[qs[l0 + 1] | ((qh << (7 - l0)) & 0x100)]);
+
+ const int signs0 = __vcmpne4(((signs_packed_8[l0/2] & 0x03) << 7) | ((signs_packed_8[l0/2] & 0x0C) << 21), 0x00000000);
+ const int signs1 = __vcmpne4(((signs_packed_8[l0/2] & 0x30) << 3) | ((signs_packed_8[l0/2] & 0xC0) << 17), 0x00000000);
+
+ const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
+ const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
+
+ const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+ const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
+
+ sumi = ggml_cuda_dp4a(grid_l, u0, sumi);
+ sumi = ggml_cuda_dp4a(grid_h, u1, sumi);
+ }
+
+ sumi *= 1 + 2*((bq3->scales[iqs/4] >> ((iqs << 1) & 0x04)) & 0x0F);
+
+ const float d = __half2float(bq3->d) * __low2float(bq8_1[iqs/2].ds);
+ return d * sumi;
+}
+
+#define VDR_IQ1_S_Q8_1_MMVQ 1
+#define VDR_IQ1_S_Q8_1_MMQ 1
+
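+// iq1_s: the 11-bit grid index combines a qs byte with 3 bits from qh; qh also carries a
+// 3-bit scale and a sign bit for the constant IQ1S_DELTA shift, which is applied through
+// the precomputed q8_1 block sum in ds.y.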
+static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+ const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx;
+
+ const int qs_packed = get_int_b2(bq1->qs, iqs);
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ const int qh = bq1->qh[iqs];
+
+ int sumi = 0;
+#pragma unroll
+ for (int l0 = 0; l0 < 8; l0 += 2) {
+ const int grid = iq1s_grid_gpu[qs[l0/2] | (((qh >> 3*(l0/2)) & 0x07) << 8)];
+
+ const int grid0 = (grid >> 0) & 0x0F0F0F0F;
+ const int grid1 = (grid >> 4) & 0x0F0F0F0F;
+
+ const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
+ const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);
+
+ sumi = ggml_cuda_dp4a(grid0, u0, sumi);
+ sumi = ggml_cuda_dp4a(grid1, u1, sumi);
+ }
+
+ const float d1q = __half2float(bq1->d) * (((qh >> 11) & 0x0E) + 1);
+ const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000);
+ const float2 ds = __half22float2(bq8_1[iqs].ds);
+ return d1q * (ds.x*sumi + ds.y*delta);
+}
+
+#define VDR_IQ1_M_Q8_1_MMVQ 1
+#define VDR_IQ1_M_Q8_1_MMQ 1
+
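+// iq1_m: same iq1s grid lookup as iq1_s, but with per-16-value 3-bit scales and a
+// superblock fp16 scale reassembled from the top nibbles of the four 16-bit scale words.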
+static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq1_m * bq1 = (const block_iq1_m *) vbq + kbx;
+
+ const int qs_packed = get_int_b4(bq1->qs, iqs);
+ const uint8_t * qs = (const uint8_t *) &qs_packed;
+
+ int sumi[2] = {0};
+ float sumf[2] = {0.0f};
+#pragma unroll
+ for (int l0 = 0; l0 < 8; l0 += 2) {
+ const int qhl = bq1->qh[2*iqs + l0/4] >> (4 * ((l0/2) % 2));
+
+ const int grid = iq1s_grid_gpu[qs[l0/2] | ((qhl & 0x07) << 8)];
+
+ const int grid0 = (grid >> 0) & 0x0F0F0F0F;
+ const int grid1 = (grid >> 4) & 0x0F0F0F0F;
+
+ const int u0 = get_int_b4(bq8_1[iqs].qs, l0 + 0);
+ const int u1 = get_int_b4(bq8_1[iqs].qs, l0 + 1);
+
+ sumi[l0/4] = ggml_cuda_dp4a(grid0, u0, sumi[l0/4]);
+ sumi[l0/4] = ggml_cuda_dp4a(grid1, u1, sumi[l0/4]);
+
+ const float delta = -1.0f + IQ1M_DELTA - (qhl & 0x08) * (2.0f*IQ1M_DELTA/0x08);
+ int sumy = 0;
+ sumy = ggml_cuda_dp4a(u0, 0x01010101, sumy);
+ sumy = ggml_cuda_dp4a(u1, 0x01010101, sumy);
+ sumf[l0/4] += delta*sumy;
+ }
+
+ const uint16_t * sc = (const uint16_t *) bq1->scales;
+
+ iq1m_scale_t scale;
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00F0) | ((sc[2] >> 4) & 0x0F00) | (sc[3] & 0xF000);
+ const float d = __half2float(scale.f16) * __low2float(bq8_1[iqs].ds);
+
+ const int tmp = sc[iqs/2] >> (6*(iqs%2));
+ const int sc0 = 2*((tmp >> 0) & 0x07) + 1;
+ const int sc1 = 2*((tmp >> 3) & 0x07) + 1;
+ return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
+}
+
+#define VDR_IQ4_NL_Q8_1_MMVQ 2
+#define VDR_IQ4_NL_Q8_1_MMQ 4
+
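+// iq4_nl: 4-bit indices are mapped to signed values through the nonlinear lookup table
+// kvalues_iq4nl before the dp4a accumulation; a single fp16 scale covers the block.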
+static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq4_nl * bq4 = (const block_iq4_nl *) vbq + kbx;
+
+ const int * q8 = (const int *) bq8_1->qs + iqs;
+
+ int sumi = 0;
+#pragma unroll
+ for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
+ const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
+
+ sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
+ sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
+ }
+
+ const float d = __half2float(bq4->d) * __low2float(bq8_1->ds);
+ return d * sumi;
+}
+
+#define VDR_IQ4_XS_Q8_1_MMVQ 4
+#define VDR_IQ4_XS_Q8_1_MMQ 4
+
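+// iq4_xs: same nonlinear 4-bit table as iq4_nl, plus a 6-bit sub-block scale built from
+// a low nibble in scales_l and two high bits in scales_h, offset by -32.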
+static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq + kbx;
+
+ int sumi = 0;
+#pragma unroll
+ for (int j = 0; j < 4; ++j) {
+ const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
+
+ const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
+ const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
+
+ sumi = ggml_cuda_dp4a(v.x, u0, sumi);
+ sumi = ggml_cuda_dp4a(v.y, u1, sumi);
+ }
+
+ const int ls = ((bq4->scales_l[iqs/8] >> (iqs & 0x04)) & 0x0F) | (((bq4->scales_h >> (iqs/2)) & 0x03) << 4);
+ sumi *= ls - 32;
+
+ const float d = __half2float(bq4->d) * __low2float(bq8_1[iqs/4].ds);
+ return d * sumi;
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h b/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h
new file mode 100644
index 0000000..ba032cf
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include <cuda_runtime.h>
+#include <cuda.h>
+#include <cublas_v2.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+#if CUDART_VERSION >= 12050
+#include <cuda_fp8.h>
+#endif // CUDART_VERSION >= 12050
+
+#if CUDART_VERSION >= 12080
+#include <cuda_fp4.h>
+#endif // CUDART_VERSION >= 12080
+
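+// Older toolkits lack some of the cuBLAS compute-type and driver attribute names used
+// below, so map them to their earlier equivalents.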
+#if CUDART_VERSION < 11020
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
+#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
+#define CUBLAS_COMPUTE_16F CUDA_R_16F
+#define CUBLAS_COMPUTE_32F CUDA_R_32F
+#define cublasComputeType_t cudaDataType_t
+#endif // CUDART_VERSION < 11020
diff --git a/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h b/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h
new file mode 100644
index 0000000..5cc1b54
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h
@@ -0,0 +1,278 @@
+#pragma once
+
+#define HIP_DISABLE_WARP_SYNC_BUILTINS 1
+#include <hip/hip_runtime.h>
+#include <hipblas/hipblas.h>
+#include <hip/hip_fp16.h>
+#include <hip/hip_bf16.h>
+
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+#include <rocwmma/rocwmma-version.hpp>
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
+
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH 0
+#define CUDA_R_16F HIPBLAS_R_16F
+#define CUDA_R_16BF HIPBLAS_R_16B
+#define CUDA_R_32F HIPBLAS_R_32F
+#define CUBLAS_SIDE_RIGHT HIPBLAS_SIDE_RIGHT
+#define CUBLAS_FILL_MODE_UPPER HIPBLAS_FILL_MODE_UPPER
+#define CUBLAS_DIAG_NON_UNIT HIPBLAS_DIAG_NON_UNIT
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
+#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
+#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
+#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
+#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
+#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
+#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
+#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
+#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+#define __all_sync(mask, var) __all(var)
+#define __any_sync(mask, var) __any(var)
+#define cublasStrsmBatched hipblasStrsmBatched
+#define cublasCreate hipblasCreate
+#define cublasDestroy hipblasDestroy
+#define cublasGemmEx hipblasGemmEx
+#define cublasGemmBatchedEx hipblasGemmBatchedEx
+#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
+#define cublasHandle_t hipblasHandle_t
+#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+#define cublasStatus_t hipblasStatus_t
+#define cublasOperation_t hipblasOperation_t
+#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch
+#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
+#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
+#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
+#define cudaDeviceGetAttribute hipDeviceGetAttribute
+#define cudaDeviceProp hipDeviceProp_t
+#define cudaDeviceSynchronize hipDeviceSynchronize
+#define cudaError_t hipError_t
+#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
+#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
+#define cudaEventCreateWithFlags hipEventCreateWithFlags
+#define cudaEventDisableTiming hipEventDisableTiming
+#define cudaEventRecord hipEventRecord
+#define cudaEventSynchronize hipEventSynchronize
+#define cudaEvent_t hipEvent_t
+#define cudaEventDestroy hipEventDestroy
+#define cudaFree hipFree
+#define cudaFreeHost hipHostFree
+#define cudaGetDevice hipGetDevice
+#define cudaGetDeviceCount hipGetDeviceCount
+#define cudaGetDeviceProperties hipGetDeviceProperties
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaHostRegister hipHostRegister
+#define cudaHostRegisterPortable hipHostRegisterPortable
+#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
+#define cudaHostUnregister hipHostUnregister
+#define cudaLaunchCooperativeKernel hipLaunchCooperativeKernel
+#define cudaLaunchHostFunc hipLaunchHostFunc
+#define cudaMalloc hipMalloc
+#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
+#define cudaMallocManaged hipMallocManaged
+#define cudaMemAdvise hipMemAdvise
+#define cudaMemcpy hipMemcpy
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
+#define cudaMemcpy2DAsync hipMemcpy2DAsync
+#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaMemcpyKind hipMemcpyKind
+#define cudaMemset hipMemset
+#define cudaMemsetAsync hipMemsetAsync
+#define cudaMemGetInfo hipMemGetInfo
+#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
+#define cudaSetDevice hipSetDevice
+#define cuDeviceGet hipDeviceGet
+#define CUdevice hipDevice_t
+#define CUdeviceptr hipDeviceptr_t
+#define cuMemUnmap hipMemUnmap
+#define CUmemAccessDesc hipMemAccessDesc
+#define cuMemAddressFree hipMemAddressFree
+#define cuMemRelease hipMemRelease
+#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
+#define cuMemCreate hipMemCreate
+#define cuMemAddressReserve hipMemAddressReserve
+#define cuMemMap hipMemMap
+#define cuMemSetAccess hipMemSetAccess
+#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
+#define CUmemAllocationProp hipMemAllocationProp
+#define cuDeviceGetAttribute hipDeviceGetAttribute
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamDestroy hipStreamDestroy
+#define cudaStreamFireAndForget hipStreamFireAndForget
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamPerThread hipStreamPerThread
+#define cudaStreamSynchronize hipStreamSynchronize
+#define cudaStreamWaitEvent hipStreamWaitEvent
+#define cudaGraphExec_t hipGraphExec_t
+#define cudaGraphNode_t hipGraphNode_t
+#define cudaKernelNodeParams hipKernelNodeParams
+#define cudaGraphExecDestroy hipGraphExecDestroy
+#define cudaGraphLaunch hipGraphLaunch
+#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
+#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
+#define cudaGraphNodeType hipGraphNodeType
+#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
+#define cudaGraphInstantiate hipGraphInstantiate
+#define cudaStreamEndCapture hipStreamEndCapture
+#define cudaGraphDestroy hipGraphDestroy
+#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
+#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
+#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
+#define cudaGraphNodeGetType hipGraphNodeGetType
+#define cudaGraphGetNodes hipGraphGetNodes
+#define cudaGraphExecUpdate hipGraphExecUpdate
+#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
+#define cudaStreamBeginCapture hipStreamBeginCapture
+#define cudaGraph_t hipGraph_t
+#define cudaStream_t hipStream_t
+#define cudaSuccess hipSuccess
+#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
+#define cudaFuncSetAttribute hipFuncSetAttribute
+#define cudaFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize
+#define __trap() do { abort(); __builtin_unreachable(); } while(0)
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
+#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
+#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
+#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
+#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
+#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
+#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
+#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
+
+#if HIP_VERSION >= 60500000
+#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
+#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
+#define cublasComputeType_t hipblasComputeType_t
+#define cudaDataType_t hipDataType
+#else
+#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
+#define cublasComputeType_t hipblasDatatype_t
+#define cudaDataType_t hipblasDatatype_t
+#endif // HIP_VERSION >= 60500000
+
+#if !defined(__HIP_PLATFORM_AMD__)
+#error "The HIP backend supports only AMD targets"
+#endif // !defined(__HIP_PLATFORM_AMD__)
+
+#define __CUDA_ARCH__ 1300
+
+#if defined(__gfx900__) || defined(__gfx906__)
+#define GCN5
+#endif // defined(__gfx900__) || defined(__gfx906__)
+
+#if defined(__gfx803__)
+#define GCN4
+#endif // defined(__gfx803__)
+
+#if defined(GCN5) || defined(GCN4)
+#define GCN
+#endif // defined(GCN5) || defined(GCN4)
+
+#if defined(__gfx942__)
+#define CDNA3
+#endif // defined(__gfx942__)
+
+#if defined(__gfx90a__)
+#define CDNA2
+#endif // defined(__gfx90a__)
+
+#if defined(__gfx908__)
+#define CDNA1
+#endif // defined(__gfx908__)
+
+#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
+#define CDNA // For the entire family
+#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
+
+#if defined(__GFX12__)
+#define RDNA4
+#endif // defined(__GFX12__)
+
+#if defined(__GFX11__)
+#define RDNA3
+#endif // defined(__GFX11__)
+
+#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
+ defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
+#define RDNA2
+#endif
+
+#if defined(__gfx1010__) || defined(__gfx1012__)
+#define RDNA1
+#endif // defined(__gfx1010__) || defined(__gfx1012__)
+
+#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
+#define RDNA // For the entire family
+#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
+
+#ifndef __has_builtin
+ #define __has_builtin(x) 0
+#endif
+
+typedef __hip_bfloat16 nv_bfloat16;
+typedef __hip_bfloat162 nv_bfloat162;
+
+typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
+typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
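+// Per-byte SIMD helpers used by the quantization kernels, standing in for CUDA's
+// __vsubss4/__vsub4/__vcmpeq4/__vcmpne4 intrinsics via 4-wide vector types.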
+static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
+ const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
+ const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
+#if __has_builtin(__builtin_elementwise_sub_sat)
+ const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
+ return reinterpret_cast<const int &>(c);
+#else
+ int8x4_t c;
+ int16_t tmp;
+#pragma unroll
+ for (int i = 0; i < 4; i++) {
+ tmp = va[i] - vb[i];
+ if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
+ if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
+ c[i] = tmp;
+ }
+ return reinterpret_cast<int &>(c);
+#endif // __has_builtin(__builtin_elementwise_sub_sat)
+}
+
+static __device__ __forceinline__ int __vsub4(const int a, const int b) {
+ return __vsubss4(a, b);
+}
+
+static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
+ const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
+ const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
+ unsigned int c;
+ uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
+#pragma unroll
+ for (int i = 0; i < 4; ++i) {
+ vc[i] = va[i] == vb[i] ? 0xff : 0x00;
+ }
+ return c;
+}
+
+static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
+ const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
+ const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
+ unsigned int c;
+ uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
+#pragma unroll
+ for (int i = 0; i < 4; ++i) {
+ vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
+ }
+ return c;
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h b/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h
new file mode 100644
index 0000000..1abb8ac
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h
@@ -0,0 +1,147 @@
+#pragma once
+
+#include <musa_runtime.h>
+#include <musa.h>
+#include <mublas.h>
+#include <musa_bf16.h>
+#include <musa_fp16.h>
+#define CUBLAS_COMPUTE_16F CUDA_R_16F
+#define CUBLAS_COMPUTE_32F CUDA_R_32F
+#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
+#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N MUBLAS_OP_N
+#define CUBLAS_OP_T MUBLAS_OP_T
+#define CUBLAS_DEFAULT_MATH MUBLAS_DEFAULT_MATH
+#define CUBLAS_SIDE_RIGHT MUBLAS_SIDE_RIGHT
+#define CUBLAS_FILL_MODE_UPPER MUBLAS_FILL_MODE_UPPER
+#define CUBLAS_DIAG_NON_UNIT MUBLAS_DIAG_NON_UNIT
+#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
+#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH
+#define CUDA_R_16F MUSA_R_16F
+#define CUDA_R_16BF MUSA_R_16BF
+#define CUDA_R_32F MUSA_R_32F
+#define cublasStrsmBatched mublasStrsmBatched
+#define cublasComputeType_t cudaDataType_t
+#define cublasCreate mublasCreate
+#define cublasDestroy mublasDestroy
+#define cublasGemmEx mublasGemmEx
+#define cublasGemmBatchedEx mublasGemmBatchedEx
+#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
+#define cublasHandle_t mublasHandle_t
+#define cublasSetMathMode mublasSetMathMode
+#define cublasSetStream mublasSetStream
+#define cublasSgemm mublasSgemm
+#define cublasStatus_t mublasStatus_t
+#define cublasOperation_t mublasOperation_t
+#define cublasGetStatusString mublasGetStatusString
+#define cudaDataType_t musaDataType_t
+#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
+#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
+#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
+#define cudaDeviceProp musaDeviceProp
+#define cudaDeviceSynchronize musaDeviceSynchronize
+#define cudaError_t musaError_t
+#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
+#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
+#define cudaEventCreateWithFlags musaEventCreateWithFlags
+#define cudaEventDisableTiming musaEventDisableTiming
+#define cudaEventRecord musaEventRecord
+#define cudaEventSynchronize musaEventSynchronize
+#define cudaEvent_t musaEvent_t
+#define cudaEventDestroy musaEventDestroy
+#define cudaFree musaFree
+#define cudaFreeHost musaFreeHost
+#define cudaGetDevice musaGetDevice
+#define cudaGetDeviceCount musaGetDeviceCount
+#define cudaGetDeviceProperties musaGetDeviceProperties
+#define cudaGetErrorString musaGetErrorString
+#define cudaGetLastError musaGetLastError
+#define cudaHostRegister musaHostRegister
+#define cudaHostRegisterPortable musaHostRegisterPortable
+#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
+#define cudaHostUnregister musaHostUnregister
+#define cudaLaunchCooperativeKernel musaLaunchCooperativeKernel
+#define cudaLaunchHostFunc musaLaunchHostFunc
+#define cudaMalloc musaMalloc
+#define cudaMallocHost musaMallocHost
+#define cudaMallocManaged musaMallocManaged
+#define cudaMemcpy musaMemcpy
+#define cudaMemcpyAsync musaMemcpyAsync
+#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
+#define cudaMemcpy2DAsync musaMemcpy2DAsync
+#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
+#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
+#define cudaMemcpyKind musaMemcpyKind
+#define cudaMemset musaMemset
+#define cudaMemsetAsync musaMemsetAsync
+#define cudaMemGetInfo musaMemGetInfo
+#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
+#define cudaSetDevice musaSetDevice
+#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
+#define cudaStreamDestroy musaStreamDestroy
+#define cudaStreamFireAndForget musaStreamFireAndForget
+#define cudaStreamNonBlocking musaStreamNonBlocking
+#define cudaStreamPerThread musaStreamPerThread
+#define cudaStreamSynchronize musaStreamSynchronize
+#define cudaStreamWaitEvent musaStreamWaitEvent
+#define cudaStream_t musaStream_t
+#define cudaSuccess musaSuccess
+
+// Additional mappings for MUSA virtual memory pool
+#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
+#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
+#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
+#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
+#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
+#define CUdevice MUdevice
+#define CUdeviceptr MUdeviceptr
+#define CUmemAccessDesc MUmemAccessDesc
+#define CUmemAllocationProp MUmemAllocationProp
+#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
+#define cuDeviceGet muDeviceGet
+#define cuDeviceGetAttribute muDeviceGetAttribute
+#define cuMemAddressFree muMemAddressFree
+#define cuMemAddressReserve muMemAddressReserve
+#define cuMemCreate muMemCreate
+#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
+#define cuMemMap muMemMap
+#define cuMemRelease muMemRelease
+#define cuMemSetAccess muMemSetAccess
+#define cuMemUnmap muMemUnmap
+#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
+#define cudaFuncSetAttribute musaFuncSetAttribute
+#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
+#define make_cudaExtent make_musaExtent
+#define make_cudaPitchedPtr make_musaPitchedPtr
+
+// Additional mappings for MUSA graphs
+#define CUDA_SUCCESS MUSA_SUCCESS
+#define CUresult MUresult
+#define cuGetErrorString muGetErrorString
+#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
+#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
+#define cudaGraphDestroy musaGraphDestroy
+#define cudaGraphExecDestroy musaGraphExecDestroy
+#define cudaGraphExec_t musaGraphExec_t
+#define cudaGraphExecUpdate musaGraphExecUpdate
+#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
+#define cudaGraphGetNodes musaGraphGetNodes
+#define cudaGraphInstantiate musaGraphInstantiate
+#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
+#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
+#define cudaGraphLaunch musaGraphLaunch
+#define cudaGraphNodeGetType musaGraphNodeGetType
+#define cudaGraphNode_t musaGraphNode_t
+#define cudaGraphNodeType musaGraphNodeType
+#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
+#define cudaGraph_t musaGraph_t
+#define cudaKernelNodeParams musaKernelNodeParams
+#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
+#define cudaStreamBeginCapture musaStreamBeginCapture
+#define cudaStreamEndCapture musaStreamEndCapture
+#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
+
+typedef __mt_bfloat16 nv_bfloat16;
+typedef __mt_bfloat162 nv_bfloat162;
diff --git a/llama.cpp/ggml/src/ggml-cuda/wkv.cu b/llama.cpp/ggml/src/ggml-cuda/wkv.cu
new file mode 100644
index 0000000..d2fced7
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/wkv.cu
@@ -0,0 +1,199 @@
+#include "common.cuh"
+#include "wkv.cuh"
+
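+// WKV6 recurrence: one thread block per (batch, head) pair with head_size threads. Each
+// thread keeps a head_size-element slice of the per-head state in registers while k, r,
+// tf and td for the current token are staged in shared memory.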
+template <int block_size>
+static __global__ void rwkv_wkv_f32(const int B, const int T, const int C, const int H, const float * k, const float * v, const float * r, const float * tf, const float * td, const float * s, float * dst) {
+ const int tid = threadIdx.x;
+ const int bid = blockIdx.x;
+
+ const int head_size = block_size;
+ const int batch_i = bid / H;
+ const int head_i = bid % H;
+ const int state_size = C * head_size;
+ const int n_seq_tokens = T / B;
+
+ float state[head_size];
+ __shared__ float _k[head_size], _r[head_size], _tf[head_size], _td[head_size];
+
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
+ }
+
+ __syncthreads();
+ _tf[tid] = tf[head_i * head_size + tid];
+ __syncthreads();
+
+ for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+ __syncthreads();
+ _k[tid] = k[t];
+ _r[tid] = r[t];
+ _td[tid] = td[t];
+ __syncthreads();
+
+ const float _v = v[t];
+ float y = 0;
+ for (int j = 0; j < head_size; j += 4) {
+ const float4& k = (float4&)(_k[j]);
+ const float4& r = (float4&)(_r[j]);
+ const float4& tf = (float4&)(_tf[j]);
+ const float4& td = (float4&)(_td[j]);
+ float4& s = (float4&)(state[j]);
+ float4 kv;
+
+ kv.x = k.x * _v;
+ kv.y = k.y * _v;
+ kv.z = k.z * _v;
+ kv.w = k.w * _v;
+
+ y += r.x * (tf.x * kv.x + s.x);
+ y += r.y * (tf.y * kv.y + s.y);
+ y += r.z * (tf.z * kv.z + s.z);
+ y += r.w * (tf.w * kv.w + s.w);
+
+ s.x = s.x * td.x + kv.x;
+ s.y = s.y * td.y + kv.y;
+ s.z = s.z * td.z + kv.z;
+ s.w = s.w * td.w + kv.w;
+ }
+ dst[t] = y;
+ }
+
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
+ }
+}
+
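+// WKV7 recurrence: same block/thread layout as above; per token the state decays by w
+// and accumulates kv plus sa*b, where sa is the dot product of a with the current state.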
+template <int block_size>
+static __global__ void rwkv_wkv7_f32(const int B, const int T, const int C, const int H, const float * r, const float * w, const float * k, const float * v, const float * a, const float * b, const float * s, float * dst) {
+ const int tid = threadIdx.x;
+ const int bid = blockIdx.x;
+
+ const int head_size = block_size;
+ const int batch_i = bid / H;
+ const int head_i = bid % H;
+ const int state_size = C * head_size;
+ const int n_seq_tokens = T / B;
+
+ float state[head_size];
+ __shared__ float _r[head_size], _w[head_size], _k[head_size], _a[head_size], _b[head_size];
+
+#ifndef GGML_USE_MUSA
+ #pragma unroll
+#endif
+ for (int i = 0; i < head_size; i++) {
+ state[i] = s[batch_i * state_size + head_i * head_size * head_size + tid * head_size + i];
+ }
+
+ for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid; t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid; t += C) {
+ __syncthreads();
+ _r[tid] = r[t];
+ _w[tid] = w[t];
+ _k[tid] = k[t];
+ _a[tid] = a[t];
+ _b[tid] = b[t];
+ __syncthreads();
+
+ float sa = 0;
+ #pragma unroll
+ for (int j = 0; j < head_size; j += 4)
+ {
+ const float4& a = (float4&)(_a[j]);
+ const float4& s = (float4&)(state[j]);
+ sa += a.x * s.x;
+ sa += a.y * s.y;
+ sa += a.z * s.z;
+ sa += a.w * s.w;
+ }
+
+ const float _v = v[t];
+ float y = 0;
+ for (int j = 0; j < head_size; j += 4) {
+ const float4& r = (float4&)(_r[j]);
+ const float4& w = (float4&)(_w[j]);
+ const float4& k = (float4&)(_k[j]);
+ const float4& b = (float4&)(_b[j]);
+ float4& s = (float4&)(state[j]);
+ float4 kv;
+
+ kv.x = k.x * _v;
+ kv.y = k.y * _v;
+ kv.z = k.z * _v;
+ kv.w = k.w * _v;
+
+ s.x = s.x * w.x + kv.x + sa * b.x;
+ s.y = s.y * w.y + kv.y + sa * b.y;
+ s.z = s.z * w.z + kv.z + sa * b.z;
+ s.w = s.w * w.w + kv.w + sa * b.w;
+
+ y += s.x * r.x;
+ y += s.y * r.y;
+ y += s.z * r.z;
+ y += s.w * r.w;
+ }
+ dst[t] = y;
+ }
+
+ #pragma unroll
+ for (int i = 0; i < head_size; i++) {
+ dst[T * C + batch_i * state_size + head_i * head_size * head_size + tid * head_size + i] = state[i];
+ }
+}
+
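+// Host-side launchers: B, T, C and H are read from the source tensors, and the kernel is
+// launched with one block per (batch, head) and head_size (= C/H) threads, which must be
+// CUDA_WKV_BLOCK_SIZE or twice that (64 or 128).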
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const float * k_d = (const float *)dst->src[0]->data;
+ const float * v_d = (const float *)dst->src[1]->data;
+ const float * r_d = (const float *)dst->src[2]->data;
+ const float * tf_d = (const float *)dst->src[3]->data;
+ const float * td_d = (const float *)dst->src[4]->data;
+ const float * s_d = (const float *)dst->src[5]->data;
+
+ const int64_t B = dst->src[5]->ne[1];
+ const int64_t T = dst->src[0]->ne[2];
+ const int64_t C = dst->ne[0];
+ const int64_t H = dst->src[0]->ne[1];
+
+ float * dst_d = (float *)dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
+ GGML_ASSERT(C % H == 0);
+ GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
+
+ if (C / H == CUDA_WKV_BLOCK_SIZE) {
+ rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+ } else {
+ rwkv_wkv_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d);
+ }
+}
+
+void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const float * r_d = (const float *)dst->src[0]->data;
+ const float * w_d = (const float *)dst->src[1]->data;
+ const float * k_d = (const float *)dst->src[2]->data;
+ const float * v_d = (const float *)dst->src[3]->data;
+ const float * a_d = (const float *)dst->src[4]->data;
+ const float * b_d = (const float *)dst->src[5]->data;
+ const float * s_d = (const float *)dst->src[6]->data;
+
+ const int64_t B = dst->src[6]->ne[1];
+ const int64_t T = dst->src[0]->ne[2];
+ const int64_t C = dst->ne[0];
+ const int64_t H = dst->src[0]->ne[1];
+
+ float * dst_d = (float *)dst->data;
+
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(dst->src[6]->type == GGML_TYPE_F32);
+ GGML_ASSERT(C % H == 0);
+ GGML_ASSERT(C / H == CUDA_WKV_BLOCK_SIZE || C / H == CUDA_WKV_BLOCK_SIZE * 2);
+
+ if (C / H == CUDA_WKV_BLOCK_SIZE) {
+ rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+ } else {
+ rwkv_wkv7_f32<CUDA_WKV_BLOCK_SIZE * 2><<<B * H, C / H, 0, stream>>>(B, T, C, H, r_d, w_d, k_d, v_d, a_d, b_d, s_d, dst_d);
+ }
+}
diff --git a/llama.cpp/ggml/src/ggml-cuda/wkv.cuh b/llama.cpp/ggml/src/ggml-cuda/wkv.cuh
new file mode 100644
index 0000000..9623dd7
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cuda/wkv.cuh
@@ -0,0 +1,7 @@
+#include "common.cuh"
+
+#define CUDA_WKV_BLOCK_SIZE 64
+
+void ggml_cuda_op_rwkv_wkv6(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rwkv_wkv7(ggml_backend_cuda_context & ctx, ggml_tensor * dst);