1#!/usr/bin/env python3
 2
 3from glob import glob
 4import os
 5
 6HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
 7
 8TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
 9
10SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
11
12#include "../fattn-tile.cuh"
13
14DECL_FATTN_TILE_CASE({head_size_kq}, {head_size_v});
15"""
16
17SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
18
19#include "../fattn-vec.cuh"
20
21DECL_FATTN_VEC_CASE( 64, {type_k}, {type_v});
22DECL_FATTN_VEC_CASE(128, {type_k}, {type_v});
23DECL_FATTN_VEC_CASE(256, {type_k}, {type_v});
24"""
25
26SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
27
28#include "../fattn-mma-f16.cuh"
29
30"""
31
32SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"
33
34TYPES_MMQ = [
35    "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
36    "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
37    "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
38    "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4"
39]
40
41SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
42
43#include "../mmq.cuh"
44
45DECL_MMQ_CASE({type});
46"""
47
48SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
49
50#include "../mmf.cuh"
51
52DECL_MMF_CASE({type});
53"""
54
55
def get_short_name(long_quant_name):
    """Return the lowercase tag of a ggml type enumerator, for use in file names.

    Every occurrence of ``GGML_TYPE_`` is removed before lowercasing, so
    ``"GGML_TYPE_Q4_0"`` becomes ``"q4_0"``.
    """
    without_prefix = long_quant_name.replace("GGML_TYPE_", "")
    return without_prefix.lower()
58
59
def _head_size_v(head_size_kq):
    """Return the V head size that is instantiated together with ``head_size_kq``.

    K/Q and V head sizes are equal for every entry except 576, which is paired
    with a V head size of 512.
    """
    return 512 if head_size_kq == 576 else head_size_kq


def _write_source(filename, text):
    """Write ``text`` to ``filename``, replacing any existing file.

    ``newline="\\n"`` disables text-mode newline translation, which would
    otherwise emit CRLF line endings on Windows; this keeps regenerated files
    byte-identical on every platform. The encoding is pinned to UTF-8 instead
    of depending on the locale.
    """
    with open(filename, "w", encoding="utf-8", newline="\n") as f:
        f.write(text)


# Start from a clean slate: remove every previously generated instance file.
# NOTE(review): this globs the *current working directory*, so the script must
# be run from the directory that holds the generated instances.
for filename in glob("*.cu"):
    os.remove(filename)

# Tile kernels: one file per K/Q head size.
for head_size_kq in HEAD_SIZES_KQ:
    head_size_v = _head_size_v(head_size_kq)
    _write_source(
        f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu",
        SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v),
    )

# Vector kernels: one file per (K type, V type) pair.
for type_k in TYPES_KV:
    for type_v in TYPES_KV:
        _write_source(
            f"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu",
            SOURCE_FATTN_VEC.format(type_k=type_k, type_v=type_v),
        )

# MMA kernels: one file per split of ncols into ncols1 * ncols2, each holding
# the case lines for the head sizes that survive the filters below.
for ncols in [8, 16, 32, 64]:
    for ncols2 in [1, 2, 4, 8, 16, 32]:
        if ncols2 > ncols:
            continue
        ncols1 = ncols // ncols2

        cases = []
        for head_size_kq in HEAD_SIZES_KQ:
            # Head sizes 40 and 72 get no MMA instances.
            if head_size_kq in (40, 72):
                continue
            # ncols2 of 16 or 32 is only emitted for head size 576 ...
            if head_size_kq != 576 and ncols2 in (16, 32):
                continue
            # ... and head size 576 is only emitted for ncols2 of 4, 16 or 32.
            if head_size_kq == 576 and ncols2 not in (4, 16, 32):
                continue
            cases.append(
                SOURCE_FATTN_MMA_CASE.format(
                    ncols1=ncols1,
                    ncols2=ncols2,
                    head_size_kq=head_size_kq,
                    head_size_v=_head_size_v(head_size_kq),
                )
            )

        _write_source(
            f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu",
            SOURCE_FATTN_MMA_START + "".join(cases),
        )

# MMQ kernels: one file per quantization type. The loop variable is named so
# that the builtin `type` is not rebound at module scope.
for ggml_type in TYPES_MMQ:
    _write_source(
        f"mmq-instance-{get_short_name(ggml_type)}.cu",
        SOURCE_MMQ.format(type=ggml_type),
    )

# MMF kernels: one file per column count 1..16. The template's placeholder is
# called `type`, but the value substituted is the column count.
for mmf_ncols in range(1, 17):
    _write_source(
        f"mmf-instance-ncols_{mmf_ncols}.cu",
        SOURCE_MMF.format(type=mmf_ncols),
    )