| field     | value                                                             | date                      |
|-----------|-------------------------------------------------------------------|---------------------------|
| author    | Mitja Felicijan <mitja.felicijan@gmail.com>                       | 2026-02-12 20:57:17 +0100 |
| committer | Mitja Felicijan <mitja.felicijan@gmail.com>                       | 2026-02-12 20:57:17 +0100 |
| commit    | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)                  |                           |
| tree      | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/gguf-py/tests |                           |
| download  | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz            |                           |
Engage!
Diffstat (limited to 'llama.cpp/gguf-py/tests')
| -rw-r--r-- | llama.cpp/gguf-py/tests/__init__.py      |   1 |
| -rwxr-xr-x | llama.cpp/gguf-py/tests/test_metadata.py | 238 |
| -rwxr-xr-x | llama.cpp/gguf-py/tests/test_quants.py   | 247 |
3 files changed, 486 insertions, 0 deletions
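
Both substantive files below are plain `unittest`/script-style tests for the bundled `gguf` Python package. A plausible way to run the metadata suite, assuming the working directory is `llama.cpp/gguf-py` so the local `gguf` package is picked up (the `NO_LOCAL_GGUF` environment variable checked in the files disables that shim), is `python -m unittest tests.test_metadata`, or programmatically:

```python
# Sketch of a test-runner invocation; assumes it is run from llama.cpp/gguf-py
# so that the local `gguf` package and the `tests` package are importable.
import unittest

suite = unittest.defaultTestLoader.discover("tests", pattern="test_metadata.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```

`test_quants.py` is excluded from this sketch because it needs a compiled `libggml`; see its command-line interface at the end of the diff.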
diff --git a/llama.cpp/gguf-py/tests/__init__.py b/llama.cpp/gguf-py/tests/__init__.py
new file mode 100644
index 0000000..d23ff9c
--- /dev/null
+++ b/llama.cpp/gguf-py/tests/__init__.py
@@ -0,0 +1 @@
+from .test_metadata import *
diff --git a/llama.cpp/gguf-py/tests/test_metadata.py b/llama.cpp/gguf-py/tests/test_metadata.py
new file mode 100755
index 0000000..40d484f
--- /dev/null
+++ b/llama.cpp/gguf-py/tests/test_metadata.py
@@ -0,0 +1,238 @@
+#!/usr/bin/env python3
+
+import unittest
+from pathlib import Path
+import os
+import sys
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+
+import gguf
+
+
+class TestMetadataMethod(unittest.TestCase):
+
+    def test_id_to_title(self):
+        self.assertEqual(gguf.Metadata.id_to_title("Mixtral-8x7B-Instruct-v0.1"), "Mixtral 8x7B Instruct v0.1")
+        self.assertEqual(gguf.Metadata.id_to_title("Meta-Llama-3-8B"), "Meta Llama 3 8B")
+        self.assertEqual(gguf.Metadata.id_to_title("hermes-2-pro-llama-3-8b-DPO"), "Hermes 2 Pro Llama 3 8b DPO")
+
+    def test_get_model_id_components(self):
+        # This is the basic standard form with an organization marker
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"),
+                         ('Mixtral-8x7B-Instruct-v0.1', "Mistral", 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
+
+        # Similar to the basic standard form, but without an organization marker
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"),
+                         ('Mixtral-8x7B-Instruct-v0.1', None, 'Mixtral', 'Instruct', 'v0.1', '8x7B'))
+
+        # Missing version
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct"),
+                         ('Mixtral-8x7B-Instruct', None, 'Mixtral', 'Instruct', None, '8x7B'))
+
+        # Missing finetune
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-v0.1"),
+                         ('Mixtral-8x7B-v0.1', None, 'Mixtral', None, 'v0.1', '8x7B'))
+
+        # Base name and size label only
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B"),
+                         ('Mixtral-8x7B', None, 'Mixtral', None, None, '8x7B'))
+
+        # Base name and version only
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-v0.1"),
+                         ('Mixtral-v0.1', None, 'Mixtral', None, 'v0.1', None))
+
+        ## Edge Cases ##
+
+        # This is too ambiguous... best to err on the side of caution and output nothing
+        self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral"),
+                         ('Mixtral', None, None, None, None, None))
+
+        # Basename has numbers mixed in and a size label is also provided. Must avoid capturing the number in the basename
+        self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"),
+                         ('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B'))
+
+        # Non-standard naming
+        self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"),
+                         ('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B'))
+
+        # Capture 'sub size labels', e.g. A14B in '57B-A14B' usually refers to the activated params/weight count
+        self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-57B-A14B-Instruct"),
+                         ('Qwen2-57B-A14B-Instruct', None, 'Qwen2', 'Instruct', None, '57B-A14B'))
+
+        # Check that it can handle a real model id with no version code
+        # Note that 4k in this string is non-standard; Microsoft was referring to context length rather than weight count
+        self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Phi-3-mini-4k-instruct", 4 * 10**9),
+                         ('Phi-3-mini-4k-instruct', 'microsoft', 'Phi-3', '4k-instruct', None, 'mini'))
+
+        # There are some legitimate models with only thousands of parameters
+        self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50 * 10**3),
+                         ('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50K'))
+
+        # Non-standard and not easy to disambiguate
+        self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"),
+                         ('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None))
+
+        # This is a real model_id where they append 2DPO to refer to Direct Preference Optimization
+        self.assertEqual(gguf.Metadata.get_model_id_components("crestf411/daybreak-kunoichi-2dpo-7b"),
+                         ('daybreak-kunoichi-2dpo-7b', 'crestf411', 'daybreak-kunoichi', '2dpo', None, '7B'))
+
+        # This is a real model id where the weight size has a decimal point
+        self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-0.5B-Instruct"),
+                         ('Qwen2-0.5B-Instruct', None, 'Qwen2', 'Instruct', None, '0.5B'))
+
+        # Uses an underscore in the size label
+        self.assertEqual(gguf.Metadata.get_model_id_components("smallcloudai/Refact-1_6B-fim"),
+                         ('Refact-1_6B-fim', 'smallcloudai', 'Refact', 'fim', None, '1.6B'))
+
+        # Uses Iter3 for the version
+        self.assertEqual(gguf.Metadata.get_model_id_components("UCLA-AGI/Gemma-2-9B-It-SPPO-Iter3"),
+                         ('Gemma-2-9B-It-SPPO-Iter3', 'UCLA-AGI', 'Gemma-2', 'It-SPPO', 'Iter3', '9B'))
+
+        # Has two potential versions in the basename
+        self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Hermes-2-Theta-Llama-3-8B"),
+                         ('Hermes-2-Theta-Llama-3-8B', 'NousResearch', 'Hermes-2-Theta-Llama-3', None, None, '8B'))
+
+        # Potential version in the basename
+        self.assertEqual(gguf.Metadata.get_model_id_components("SeaLLMs/SeaLLMs-v3-7B-Chat"),
+                         ('SeaLLMs-v3-7B-Chat', 'SeaLLMs', 'SeaLLMs-v3', 'Chat', None, '7B'))
+
+        # Underscore in the basename, and 1m for the context size
+        self.assertEqual(gguf.Metadata.get_model_id_components("internlm/internlm2_5-7b-chat-1m", 7 * 10**9),
+                         ('internlm2_5-7b-chat-1m', 'internlm', 'internlm2_5', 'chat-1m', None, '7B'))
+
+        # Version before the finetune name
+        self.assertEqual(gguf.Metadata.get_model_id_components("pszemraj/jamba-900M-v0.13-KIx2"),
+                         ('jamba-900M-v0.13-KIx2', 'pszemraj', 'jamba', 'KIx2', 'v0.13', '900M'))
+
+        # TODO: hf suffix which could be ignored but isn't
+        self.assertEqual(gguf.Metadata.get_model_id_components("state-spaces/mamba-2.8b-hf"),
+                         ('mamba-2.8b-hf', 'state-spaces', 'mamba', 'hf', None, '2.8B'))
+
+        # Two sizes; don't merge them, the other is the number of tokens on which it was trained
+        self.assertEqual(gguf.Metadata.get_model_id_components("abacaj/llama-161M-100B", 161 * 10**6),
+                         ('llama-161M-100B', 'abacaj', 'llama', '100b', None, '161M'))
+
+        # It's a trap, there is no size label
+        self.assertEqual(gguf.Metadata.get_model_id_components("SparseLLM/relu-100B", 1340 * 10**6),
+                         ('relu-100B', 'SparseLLM', 'relu', '100b', None, None))
+
+        # Weird size notation
+        self.assertEqual(gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"),
+                         ('bloom-7b1-petals', 'bigscience', 'bloom', 'petals', None, '7.1B'))
+
+        # Ignore full-text size labels when there are number-based ones, and deduplicate size labels
+        self.assertEqual(gguf.Metadata.get_model_id_components("MaziyarPanahi/GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1"),
+                         ('GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1', 'MaziyarPanahi', 'GreenNode-mini', 'multilingual-v1olet-Mistral-Instruct', 'v0.1', '7B'))
+
+        # Instruct in a name without a size label
+        self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/Mistral-Nemo-Instruct-2407"),
+                         ('Mistral-Nemo-Instruct-2407', 'mistralai', 'Mistral-Nemo', 'Instruct', '2407', None))
+
+        # Non-obvious splitting relying on the 'chat' keyword
+        self.assertEqual(gguf.Metadata.get_model_id_components("deepseek-ai/DeepSeek-V2-Chat-0628"),
+                         ('DeepSeek-V2-Chat-0628', 'deepseek-ai', 'DeepSeek-V2', 'Chat', '0628', None))
+
+        # Multiple versions
+        self.assertEqual(gguf.Metadata.get_model_id_components("OpenGVLab/Mini-InternVL-Chat-2B-V1-5"),
+                         ('Mini-InternVL-Chat-2B-V1-5', 'OpenGVLab', 'Mini-InternVL', 'Chat', 'V1-5', '2B'))
+
+        # TODO: DPO in the name
+        self.assertEqual(gguf.Metadata.get_model_id_components("jondurbin/bagel-dpo-2.8b-v0.2"),
+                         ('bagel-dpo-2.8b-v0.2', 'jondurbin', 'bagel-dpo', None, 'v0.2', '2.8B'))
+
+        # DPO in the name, but it can't be used for the finetune, to keep 'LLaMA-3' in the basename
+        self.assertEqual(gguf.Metadata.get_model_id_components("voxmenthe/SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized"),
+                         ('SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized', 'voxmenthe', 'SFR-Iterative-DPO-LLaMA-3', 'R-unquantized', None, '8B'))
+
+        # Too ambiguous
+        # TODO: should "base" be a 'finetune' or 'size_label'?
+        # (in this case it should be a size label, but other models use it to signal that they are not finetuned)
+        self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Florence-2-base"),
+                         ('Florence-2-base', 'microsoft', None, None, None, None))
+
+        ## Invalid cases ##
+
+        # Starts with a dash and has dashes in a row
+        self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/-Mistral--Nemo-Base-2407-"),
+                         ('-Mistral--Nemo-Base-2407-', 'mistralai', 'Mistral-Nemo-Base', None, '2407', None))
+
+        ## LoRA ##
+
+        self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B"),
+                         ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration-LoRA', None, '8B'))
+
+        # Negative size --> output is a LoRA adapter --> prune "LoRA" out of the name to avoid redundancy with the suffix
+        self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B", -1234),
+                         ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration', None, '8B'))
+
+    def test_apply_metadata_heuristic_from_model_card(self):
+        model_card = {
+            'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'],
+            'model-index': [{'name': 'Mixtral-8x7B-Instruct-v0.1', 'results': []}],
+            'language': ['en'],
+            'datasets': ['teknium/OpenHermes-2.5'],
+            'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}],
+            'base_model': ["EmbeddedLLM/Mistral-7B-Merge-14-v0", "janai-hq/trinity-v1"]
+        }
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
+        expect = gguf.Metadata()
+        expect.base_models = [{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': '14-v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}]
+        expect.tags = ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl']
+        expect.languages = ['en']
+        expect.datasets = [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]
+        self.assertEqual(got, expect)
+
+        # Base Model spec is inferred from the model id
+        model_card = {'base_models': 'teknium/OpenHermes-2.5'}
+        expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
+        self.assertEqual(got, expect)
+
+        # Base Model spec is only a url
+        model_card = {'base_models': ['https://huggingface.co/teknium/OpenHermes-2.5']}
+        expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
+        self.assertEqual(got, expect)
+
+        # Base Model spec is given directly
+        model_card = {'base_models': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]}
+        expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
+        self.assertEqual(got, expect)
+
+        # Dataset spec is inferred from the model id
+        model_card = {'datasets': 'teknium/OpenHermes-2.5'}
+        expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
+        self.assertEqual(got, expect)
+
+        # Dataset spec is only a url
+        model_card = {'datasets': ['https://huggingface.co/teknium/OpenHermes-2.5']}
+        expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
+        self.assertEqual(got, expect)
+
+        # Dataset spec is given directly
+        model_card = {'datasets': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]}
+        expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}])
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None)
+        self.assertEqual(got, expect)
+
+    def test_apply_metadata_heuristic_from_hf_parameters(self):
+        hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"}
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=hf_params, model_path=None)
+        expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
+        self.assertEqual(got, expect)
+
+    def test_apply_metadata_heuristic_from_model_dir(self):
+        model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO")
+        got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=None, model_path=model_dir_path)
+        expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B')
+        self.assertEqual(got, expect)
+
+
+if __name__ == "__main__":
+    unittest.main()
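
As a reading aid (not part of the patch): the assertions above pin down the return convention of `get_model_id_components`. A minimal sketch of consuming it, with the six-tuple layout inferred from the expected values in the tests:

```python
import gguf

# Tuple layout inferred from the test expectations above:
# (model id without the org, org, basename, finetune, version, size label)
name, org, basename, finetune, version, size_label = \
    gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B")
assert (basename, size_label) == ("Meta-Llama-3", "8B")

# The optional second argument is a total parameter count; it helps
# disambiguate size labels, e.g. 'mini' in the Phi-3 case above.
gguf.Metadata.get_model_id_components("microsoft/Phi-3-mini-4k-instruct", 4 * 10**9)
```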

diff --git a/llama.cpp/gguf-py/tests/test_quants.py b/llama.cpp/gguf-py/tests/test_quants.py
new file mode 100755
index 0000000..172fa00
--- /dev/null
+++ b/llama.cpp/gguf-py/tests/test_quants.py
@@ -0,0 +1,247 @@
+#!/usr/bin/env python3
+
+# Test gguf.quants so that it exactly matches the C implementation of the (de)quantization
+
+# NOTE: this is kind of a mess, but at least it worked for initially testing the Python implementations.
+
+from __future__ import annotations
+
+import argparse
+from math import prod
+import os
+import sys
+from pathlib import Path
+import ctypes
+import logging
+import numpy as np
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
+    sys.path.insert(0, str(Path(__file__).parent.parent))
+
+import gguf
+from gguf.constants import GGMLQuantizationType
+
+
+logger = logging.getLogger("test-quants")
+
+
+c_float_p = ctypes.POINTER(ctypes.c_float)
+
+
+class ggml_init_params(ctypes.Structure):
+    _fields_ = [
+        ("mem_size", ctypes.c_size_t),
+        ("mem_buffer", ctypes.c_void_p),
+        ("no_alloc", ctypes.c_bool),
+    ]
+
+
+class GGMLQuants:
+    libggml: ctypes.CDLL
+
+    def __init__(self, libggml: Path):
+        self.libggml = ctypes.CDLL(str(libggml))
+        self.libggml.ggml_quantize_chunk.restype = ctypes.c_size_t
+        # enum ggml_type   type,
+        # const float    * src,
+        # void           * dst,
+        # int64_t          start,
+        # int64_t          nrows,
+        # int64_t          n_per_row,
+        # const float    * imatrix)
+        self.libggml.ggml_quantize_chunk.argtypes = (
+            ctypes.c_int,
+            ctypes.POINTER(ctypes.c_float),
+            ctypes.c_void_p,
+            ctypes.c_int64,
+            ctypes.c_int64,
+            ctypes.c_int64,
+            ctypes.POINTER(ctypes.c_float),
+        )
+
+        self.libggml.ggml_quantize_requires_imatrix.restype = ctypes.c_bool
+        self.libggml.ggml_quantize_requires_imatrix.argtypes = (ctypes.c_int,)
+
+        for t in (
+            "q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
+            "q2_K", "q3_K", "q4_K", "q5_K", "q6_K",
+            "tq1_0", "tq2_0",
+            "mxfp4",
+            "iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m",
+            "iq4_nl", "iq4_xs",
+        ):
+            dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + t)
+            dequant_func.restype = None
+            dequant_func.argtypes = (ctypes.c_void_p, ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
+
+        self.libggml.ggml_fp16_to_fp32_row.restype = None
+        self.libggml.ggml_fp16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
+        self.libggml.ggml_bf16_to_fp32_row.restype = None
+        self.libggml.ggml_bf16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64)
+
+        self.libggml.ggml_init.argtypes = (ggml_init_params,)
+
+        self.libggml.ggml_init(ggml_init_params(1 * 1024 * 1024, 0, False))
+
+    def dequantize(self, tensor: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
+        result = np.zeros(gguf.quant_shape_from_byte_shape(tensor.shape, qtype), dtype=np.float32, order="C")
+        if qtype == GGMLQuantizationType.F32:
+            # no-op
+            result = tensor.view(np.float32)
+        elif qtype == GGMLQuantizationType.F16:
+            self.libggml.ggml_fp16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size)
+        elif qtype == GGMLQuantizationType.BF16:
+            self.libggml.ggml_bf16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size)
+        else:
+            lw_qname = qtype.name.lower()
+            if lw_qname[-1] == "k":
+                lw_qname = lw_qname[:-1] + "K"
+            dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + lw_qname)
+            dequant_func(tensor.ctypes.data_as(ctypes.c_void_p), result.ctypes.data_as(c_float_p), result.size)
+        return result
+
+    def quantize(self, data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray:
+        result = np.zeros(gguf.quant_shape_to_byte_shape(data.shape, qtype), dtype=np.uint8, order="C")
+        if self.libggml.ggml_quantize_requires_imatrix(qtype.value):
+            # TODO: is a column-wise sum of squares appropriate?
+            qw = np.sum((data * data).reshape((-1, data.shape[-1])), axis=0).ctypes.data_as(c_float_p)
+        else:
+            qw = ctypes.cast(0, c_float_p)
+        result_size = self.libggml.ggml_quantize_chunk(qtype.value, data.ctypes.data_as(c_float_p), result.ctypes.data_as(ctypes.c_void_p), 0, prod(data.shape[:-1]), data.shape[-1], qw)
+        assert result.size == result_size
+        return result
+
+
+def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType) -> bool:
+    same = np.array_equal(t1, t2)
+    if same:
+        return True
+    else:
+        block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
+        if t1.dtype == np.float32:
+            t1 = t1.reshape((-1, block_size))
+            t2 = t2.reshape((-1, block_size))
+        else:
+            t1 = t1.reshape((-1, type_size))
+            t2 = t2.reshape((-1, type_size))
+        x = t1.view(np.uint8) ^ t2.view(np.uint8)
+        diff_bits = np.count_nonzero(np.unpackbits(x, axis=-1), axis=-1)
+        num_bad_blocks = np.count_nonzero(diff_bits, axis=0)
+        if num_bad_blocks == 0 and t1.shape == t2.shape:
+            logger.debug("Bits are equal, but arrays don't match, likely contains NANs")
+            return True
+        logger.debug(f"{num_bad_blocks} bad blocks ({100 * num_bad_blocks / x.shape[0]:.6f}%)")
+        bad_block_id = np.argmax(diff_bits, axis=0)
+        logger.debug(f"Worst block id: {bad_block_id}")
+        logger.debug(f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}")
+
+        sum_diff_bits = np.sum(diff_bits)
+        logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits / (x.size * 8):.6f}%)")
+        return False
+
+
+def do_test(libggml_path: Path, quick: bool = False, user_type: GGMLQuantizationType | None = None):
+    ggml_quants = GGMLQuants(libggml_path)
+
+    np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n})
+
+    r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False)
+    # test zero blocks
+    r[0, 0, :] = 0
+    ## Maybe test infinities? (can make NANs, not really useful in practice)
+    # r[0, 1, 0] = np.inf
+    # r[0, 2, 0] = -np.inf
+    # r[0, 3, 0] = np.inf
+    # r[0, 3, 1] = -np.inf
+
+    for qtype in ((GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()) if user_type is None else (user_type,)):
+        has_dequantize = False
+        has_quantize = False
+
+        try:
+            gguf.dequantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][1]), dtype=np.uint8), qtype)
+            has_dequantize = True
+        except (NotImplementedError, AssertionError) as e:
+            if isinstance(e, AssertionError):
+                logger.error(f"Error with {qtype.name}: {e}")
+                raise e
+        try:
+            gguf.quantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][0]), dtype=np.float32), qtype)
+            has_quantize = True
+        except (NotImplementedError, AssertionError) as e:
+            if isinstance(e, AssertionError):
+                logger.error(f"Error with {qtype.name}: {e}")
+                raise e
+
+        if not has_dequantize and not has_quantize:
+            continue
+
+        logger.info(f"Testing {qtype.name}")
+
+        rc = r.copy(order="C")
+
+        pyq = None
+        ggq = None
+
+        if has_quantize:
+            logger.debug(f"Quantizing to {qtype.name} with Python")
+            pyq = gguf.quants.quantize(rc, qtype)
+
+            logger.debug(f"Quantizing to {qtype.name} with C")
+            ggq = ggml_quants.quantize(rc, qtype)
+
+            if qtype == GGMLQuantizationType.F16:
+                pyq = pyq.view(np.uint8)
+            quant_equal = compare_tensors(pyq, ggq, qtype)
+
+            if not quant_equal:
+                logger.error(f"Quantization to {qtype.name} does not match ❌")
+            else:
+                logger.info(f"Quantization to {qtype.name} matches exactly ✅")
+
+        if has_dequantize:
+            if ggq is None and not quick:
+                logger.debug(f"Quantizing to {qtype.name} with C")
+                ggq = ggml_quants.quantize(rc, qtype)
+
+            if ggq is not None:
+                logger.debug(f"Dequantizing from {qtype.name} with Python")
+                pydq = gguf.quants.dequantize(ggq, qtype)
+                logger.debug(f"Dequantizing from {qtype.name} with C")
+                ggdq = ggml_quants.dequantize(ggq, qtype)
+
+                dequant_equal = compare_tensors(pydq, ggdq, qtype)
+
+                if not dequant_equal:
+                    logger.error(f"Dequantization from {qtype.name} does not match ❌")
+                else:
+                    logger.info(f"Dequantization from {qtype.name} matches exactly ✅")
+
+            rq_shape = gguf.quants.quant_shape_to_byte_shape((8, 1024, 1024 // 2), qtype)
+            rq = np.random.random(rq_shape).astype(np.float16).view(np.uint8)
+
+            logger.debug(f"Dequantizing random f16 data as {qtype.name} with Python")
+            pydq = gguf.quants.dequantize(rq, qtype)
+            logger.debug(f"Dequantizing random f16 data as {qtype.name} with C")
+            ggdq = ggml_quants.dequantize(rq, qtype)
+
+            dequant_equal = compare_tensors(pydq, ggdq, qtype)
+
+            if not dequant_equal:
+                logger.error(f"Dequantization from random f16 data as {qtype.name} does not match ❌")
+            else:
+                logger.info(f"Dequantization from random f16 data as {qtype.name} matches exactly ✅")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation")
+    parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "bin" / "libggml.so", help="The path to libggml.so")
+    parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary")
+    parser.add_argument("--type", type=str, help="The quant type to test (all by default)")
+
+    args = parser.parse_args()
+
+    logging.basicConfig(level=logging.DEBUG)
+
+    do_test(args.libggml, args.quick, GGMLQuantizationType[args.type.upper()] if args.type is not None else None)
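
For a sense of what the script checks without needing a compiled `libggml`, here is a minimal pure-Python round trip through the same `gguf.quants` API exercised above (a sketch; Q8_0 uses 32-element blocks, so the last dimension must be a multiple of 32):

```python
import numpy as np
import gguf
from gguf.constants import GGMLQuantizationType

data = np.random.randn(4, 256).astype(np.float32)

q = gguf.quants.quantize(data, GGMLQuantizationType.Q8_0)    # packed uint8 blocks
dq = gguf.quants.dequantize(q, GGMLQuantizationType.Q8_0)    # back to float32

assert dq.shape == data.shape
print(f"max abs round-trip error: {np.max(np.abs(dq - data)):.4f}")  # small but nonzero
```

To compare against the C reference, the script itself is the tool, e.g. `./gguf-py/tests/test_quants.py --libggml ./build/bin/libggml.so --type q8_0 --quick` (all three flags come from the argparse block above; the library path depends on where ggml was built).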
