1#!/usr/bin/env python3
2
# Test that gguf.quants exactly matches the C implementation of the (de)quantization

# NOTE: this is kind of a mess, but at least it worked for the initial testing of the Python implementations.
6
7from __future__ import annotations
8
9import argparse
10from math import prod
11import os
12import sys
13from pathlib import Path
14import ctypes
15import logging
16import numpy as np
17
# Necessary to load the local gguf package.
# Skipped when NO_LOCAL_GGUF is set in the environment, or when no `gguf-py`
# directory is found next to this file's grandparent directory.
if "NO_LOCAL_GGUF" not in os.environ:
    if (Path(__file__).parent.parent.parent / 'gguf-py').exists():
        # Front of sys.path, so this directory is searched before any other entry
        sys.path.insert(0, str(Path(__file__).parent.parent))
21
22import gguf
23from gguf.constants import GGMLQuantizationType
24
25
# Logger used by the functions of this script.
logger = logging.getLogger("test-quants")


# ctypes spelling of C's `float *`, used when handing NumPy float buffers to libggml.
c_float_p = ctypes.POINTER(ctypes.c_float)
30
31
class ggml_init_params(ctypes.Structure):
    """ctypes layout of the parameter struct that is passed by value to ``ggml_init``."""

    # Field order and C types determine the binary layout seen by the library.
    _fields_ = [
        ("mem_size", ctypes.c_size_t),    # size_t
        ("mem_buffer", ctypes.c_void_p),  # void *
        ("no_alloc", ctypes.c_bool),      # bool
    ]
38
39
class GGMLQuants:
    """ctypes front-end to the (de)quantization entry points exported by libggml.

    The constructor loads the shared library, declares the C signatures of the
    functions this script calls and runs ``ggml_init`` once.  ``quantize`` and
    ``dequantize`` then push NumPy arrays through the C implementation.
    """

    libggml: ctypes.CDLL

    # Suffixes of the `dequantize_row_<suffix>` symbols that are looked up (and
    # therefore required to exist) when the library is loaded.
    _DEQUANT_ROW_TYPES = (
        "q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
        "q2_K", "q3_K", "q4_K", "q5_K", "q6_K",
        "tq1_0", "tq2_0",
        "mxfp4",
        "iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m",
        "iq4_nl", "iq4_xs",
    )

    def __init__(self, libggml: Path):
        """Load `libggml` and declare the signatures of the functions used below.

        Raises:
            OSError: if the shared library cannot be loaded.
            AttributeError: if one of the expected symbols is missing.
        """
        self.libggml = ctypes.CDLL(str(libggml))
        lib = self.libggml

        lib.ggml_quantize_chunk.restype = ctypes.c_size_t
        # enum ggml_type   type,
        #    const float * src,
        #           void * dst,
        #        int64_t   start,
        #        int64_t   nrows,
        #        int64_t   n_per_row,
        #    const float * imatrix) {
        lib.ggml_quantize_chunk.argtypes = (
            ctypes.c_int,
            ctypes.POINTER(ctypes.c_float),
            ctypes.c_void_p,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.c_int64,
            ctypes.POINTER(ctypes.c_float),
        )

        lib.ggml_quantize_requires_imatrix.restype = ctypes.c_bool
        lib.ggml_quantize_requires_imatrix.argtypes = (ctypes.c_int,)

        # Resolve every known row dequantizer eagerly so a missing symbol fails here.
        for t in self._DEQUANT_ROW_TYPES:
            self._dequant_row_func(t)

        row_convert_args = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
        lib.ggml_fp16_to_fp32_row.restype = None
        lib.ggml_fp16_to_fp32_row.argtypes = row_convert_args
        lib.ggml_bf16_to_fp32_row.restype = None
        lib.ggml_bf16_to_fp32_row.argtypes = row_convert_args

        # ctypes assumes a C `int` return unless told otherwise.
        # NOTE(review): ggml_init is expected to return a context pointer (declared
        # in ggml.h, not visible here); the value is not used by this script.
        lib.ggml_init.restype = ctypes.c_void_p
        lib.ggml_init.argtypes = (ggml_init_params,)

        lib.ggml_init(ggml_init_params(1 * 1024 * 1024, 0, False))

    def _dequant_row_func(self, name: str):
        """Return ``dequantize_row_<name>`` from the library with its C signature declared."""
        func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + name)
        func.restype = None
        func.argtypes = (ctypes.c_void_p, ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
        return func

    def dequantize(self, tensor: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
        """Dequantize `tensor` (data encoded as `qtype`) to float32 with the C implementation.

        Args:
            tensor: encoded data; its shape is mapped to the element shape with
                ``gguf.quant_shape_from_byte_shape``.
            qtype: encoding of `tensor`.

        Returns:
            A float32 array.  For ``F32`` this is a view of `tensor`, not a copy.
        """
        if qtype == GGMLQuantizationType.F32:
            # no-op, and no output buffer is needed
            return tensor.view(np.float32)

        # The C side walks raw memory, so the input must be laid out contiguously
        # (this is a no-op for arrays that already are).
        tensor = np.ascontiguousarray(tensor)
        result = np.zeros(gguf.quant_shape_from_byte_shape(tensor.shape, qtype), dtype=np.float32, order="C")
        out_ptr = result.ctypes.data_as(c_float_p)

        if qtype == GGMLQuantizationType.F16:
            self.libggml.ggml_fp16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), out_ptr, result.size)
        elif qtype == GGMLQuantizationType.BF16:
            self.libggml.ggml_bf16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), out_ptr, result.size)
        else:
            # C symbols are lowercase except for a trailing capital K (e.g. Q4_K -> q4_K)
            lw_qname = qtype.name.lower()
            if lw_qname.endswith("k"):
                lw_qname = lw_qname[:-1] + "K"
            dequant_func = self._dequant_row_func(lw_qname)
            dequant_func(tensor.ctypes.data_as(ctypes.c_void_p), out_ptr, result.size)
        return result

    def quantize(self, data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
        """Quantize `data` to `qtype` with ``ggml_quantize_chunk``.

        Args:
            data: values to quantize; converted to a C-contiguous float32 array
                if it is not one already.  The last axis is the row length.
            qtype: target encoding.

        Returns:
            A uint8 array shaped by ``gguf.quant_shape_to_byte_shape``.

        Raises:
            AssertionError: if the C function reports a different output size
                than the allocated buffer.
        """
        # The C function reads a flat `const float *`
        data = np.ascontiguousarray(data, dtype=np.float32)
        result = np.zeros(gguf.quant_shape_to_byte_shape(data.shape, qtype), dtype=np.uint8, order="C")
        if self.libggml.ggml_quantize_requires_imatrix(qtype.value):
            # TODO: is a column-wise sum of squares appropriate?
            # Bound to a local name so the buffer stays referenced during the C call.
            imatrix = np.ascontiguousarray(np.sum((data * data).reshape((-1, data.shape[-1])), axis=0), dtype=np.float32)
            qw = imatrix.ctypes.data_as(c_float_p)
        else:
            # NULL pointer: no importance matrix
            qw = ctypes.cast(0, c_float_p)
        result_size = self.libggml.ggml_quantize_chunk(qtype.value, data.ctypes.data_as(c_float_p), result.ctypes.data_as(ctypes.c_void_p), 0, prod(data.shape[:-1]), data.shape[-1], qw)
        assert result.size == result_size
        return result
114
115
def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType) -> bool:
    """Return whether `t1` and `t2` are equal, logging a bit-level report when they are not.

    Arrays that fail ``np.array_equal`` are split into rows of one block each
    (``block_size`` elements for float32 input, ``type_size`` bytes otherwise,
    both taken from ``gguf.GGML_QUANT_SIZES[qtype]``) and compared bit by bit.

    Args:
        t1: array under test.
        t2: reference array.
        qtype: quantization type, used only to look up the block geometry.

    Returns:
        ``True`` when the arrays are equal or differ only in a way that leaves
        every bit identical (logged as likely NaNs); ``False`` otherwise,
        including when the two arrays do not have the same byte layout.
    """
    if np.array_equal(t1, t2):
        return True

    block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
    row_len = block_size if t1.dtype == np.float32 else type_size
    t1 = t1.reshape((-1, row_len))
    t2 = t2.reshape((-1, row_len))

    bytes1 = t1.view(np.uint8)
    bytes2 = t2.view(np.uint8)
    # Check before XOR-ing: mismatched layouts would otherwise raise or broadcast silently.
    if bytes1.shape != bytes2.shape:
        logger.debug(f"Byte layouts differ: {bytes1.shape} vs {bytes2.shape}")
        return False

    x = bytes1 ^ bytes2
    # Number of differing bits in each block
    diff_bits = np.count_nonzero(np.unpackbits(x, axis=-1), axis=-1)
    num_bad_blocks = np.count_nonzero(diff_bits, axis=0)
    if num_bad_blocks == 0:
        logger.debug("Bits are equal, but arrays don't match, likely contains NANs")
        return True

    logger.debug(f"{num_bad_blocks} bad blocks ({100 * num_bad_blocks / x.shape[0]:.6f}%)")
    bad_block_id = np.argmax(diff_bits, axis=0)
    logger.debug(f"Worst block id: {bad_block_id}")
    logger.debug(f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}")

    sum_diff_bits = np.sum(diff_bits)
    logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits / (x.size * 8):.6f}%)")
    return False
142
143
def _probe_support(func, probe: np.ndarray, qtype: GGMLQuantizationType) -> bool:
    """Return whether the Python implementation `func` handles `qtype`, judged by one tiny call."""
    try:
        func(probe, qtype)
    except NotImplementedError:
        return False
    except AssertionError as e:
        logger.error(f"Error with {qtype.name}: {e}")
        raise
    return True


def _report_match(equal: bool, what: str) -> bool:
    """Log the outcome of one comparison described by `what` and pass `equal` through."""
    if equal:
        logger.info(f"{what} matches exactly ✅")
    else:
        logger.error(f"{what} does not match ❌")
    return equal


def do_test(libggml_path: Path, quick: bool = False, user_type: GGMLQuantizationType | None = None) -> bool:
    """Compare the Python (de)quantization in ``gguf`` with libggml on random data.

    For every tested type the Python and C quantizers are run on the same
    random float32 tensor and their outputs compared.  The C-quantized data and
    a buffer of random f16 bytes are then dequantized by both implementations
    and compared as well.  Types for which ``gguf`` raises
    ``NotImplementedError`` in both directions are skipped.

    Args:
        libggml_path: shared library to load through ``GGMLQuants``.
        quick: when true, skip the C quantization (and the dequantization
            check that depends on it) for types Python cannot quantize.
        user_type: test only this type; by default F16 plus every key of
            ``gguf.quants._type_traits`` is tested.

    Returns:
        ``True`` when every comparison matched, ``False`` otherwise.

    Raises:
        AssertionError: re-raised (after logging) when probing the Python
            implementation trips an assertion.
    """
    ggml_quants = GGMLQuants(libggml_path)

    np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n})

    r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False)
    # test zero blocks
    r[0, 0, :] = 0
    ## Maybe test infinities? (can make NANs, not really useful in practice)
    # r[0, 1, 0] = np.inf
    # r[0, 2, 0] = -np.inf
    # r[0, 3, 0] = np.inf
    # r[0, 3, 1] = -np.inf

    if user_type is None:
        qtypes = (GGMLQuantizationType.F16, *gguf.quants._type_traits.keys())
    else:
        qtypes = (user_type,)

    all_match = True

    for qtype in qtypes:
        # Probe each direction with a single block's worth of zeros
        block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
        has_dequantize = _probe_support(gguf.dequantize, np.zeros(type_size, dtype=np.uint8), qtype)
        has_quantize = _probe_support(gguf.quantize, np.zeros(block_size, dtype=np.float32), qtype)

        if not has_dequantize and not has_quantize:
            continue

        logger.info(f"Testing {qtype.name}")

        rc = r.copy(order="C")

        ggq = None

        if has_quantize:
            logger.debug(f"Quantizing to {qtype.name} with Python")
            pyq = gguf.quants.quantize(rc, qtype)

            logger.debug(f"Quantizing to {qtype.name} with C")
            ggq = ggml_quants.quantize(rc, qtype)

            if qtype == GGMLQuantizationType.F16:
                # The C side returns bytes, so compare the Python result as bytes too
                pyq = pyq.view(np.uint8)

            if not _report_match(compare_tensors(pyq, ggq, qtype), f"Quantization to {qtype.name}"):
                all_match = False

        if has_dequantize:
            if ggq is None and not quick:
                logger.debug(f"Quantizing to {qtype.name} with C")
                ggq = ggml_quants.quantize(rc, qtype)

            if ggq is not None:
                logger.debug(f"Dequantizing from {qtype.name} with Python")
                pydq = gguf.quants.dequantize(ggq, qtype)
                logger.debug(f"Dequantizing from {qtype.name} with C")
                ggdq = ggml_quants.dequantize(ggq, qtype)

                if not _report_match(compare_tensors(pydq, ggdq, qtype), f"Dequantization from {qtype.name}"):
                    all_match = False

            # Random f16 values reinterpreted as bytes; the uint8 view doubles the
            # last axis, hence the halved row length passed in here.
            rq_shape = gguf.quants.quant_shape_to_byte_shape((8, 1024, 1024 // 2), qtype)
            rq = np.random.random(rq_shape).astype(np.float16).view(np.uint8)

            logger.debug(f"Dequantizing random f16 data as {qtype.name} with Python")
            pydq = gguf.quants.dequantize(rq, qtype)
            logger.debug(f"Dequantizing random f16 data as {qtype.name} with C")
            ggdq = ggml_quants.dequantize(rq, qtype)

            if not _report_match(compare_tensors(pydq, ggdq, qtype), f"Dequantization from random f16 data as {qtype.name}"):
                all_match = False

    return all_match
235
236
if __name__ == "__main__":
    # Command-line entry point: parse options, enable debug logging, run the comparison.
    parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation")
    parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "bin" / "libggml.so", help="The path to libggml.so")
    parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary")
    parser.add_argument("--type", type=str, help="The quant type to test (all by default)")

    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG)

    user_type = None
    if args.type is not None:
        try:
            user_type = GGMLQuantizationType[args.type.upper()]
        except KeyError:
            # Report a usage error instead of a bare KeyError traceback
            parser.error(f"unknown quant type {args.type!r}; expected one of: {', '.join(t.name for t in GGMLQuantizationType)}")

    outcome = do_test(args.libggml, args.quick, user_type)
    # Only an explicit False signals failure; a None result keeps exit status 0.
    if outcome is False:
        sys.exit(1)